Python TensorFlow,分布式TensorFlow
分布式Tensorflow是由高性能的gRPC框架作为底层技术来支持的。这是一个通信框架gRPC(google remote procedure call),是一个高性能、跨平台的RPC框架。RPC协议,即远程过程调用协议,是指通过网络从远程计算机程序上请求服务。
demo.py(分布式TensorFlow,集群中的所有服务器都需要拷贝该代码并运行):
import tensorflow as tf
# 命令行参数
tf.app.flags.DEFINE_string("job_name", " ", "启动服务的类型 ps 或 worker")
tf.app.flags.DEFINE_integer("task_index", 0, "指定ps或者worker当中的哪一台服务器,task:0 ,task:1")
FLAGS = tf.app.flags.FLAGS # 用于获取命令行参数
# 定义全局计数的op(记录训练的步数),配合StopAtStepHook钩子使用
global_step = tf.contrib.framework.get_or_create_global_step()
# 指定集群描述对象, ps , worker (端口号随便指定一个不冲突的端口)
cluster = tf.train.ClusterSpec({"ps": ["10.211.55.3:2223"], "worker": ["192.168.65.44:2222"]})
# 创建不同的服务, ps, worker
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
# 根据不同服务做不同的事情 ps:更新保存参数 worker:指定设备去运行模型计算
if FLAGS.job_name == "ps":
# 参数服务器什么都不用干,只需要等待worker传递参数
server.join()
else: # 否则是worker服务
worker_device = "/job:worker/task:0/cpu:0/" # worker服务器中的第一台服务器中的第一个cpu
# 可以指定设备(GPU/CPU)和该设备的计算内容
# with tf.device("/job:worker/task:0/cpu:0/"): # 不是分布式
with tf.device(tf.train.replica_device_setter(
worker_device=worker_device, # 指定设备(哪个CPU或GPU)。(只能指定当前电脑中的设备)
cluster=cluster # 分布式
)):
# 该设备的计算内容: 简单做一个矩阵乘法运算
x = tf.Variable([[1, 2, 3, 4]]) # Variable变量是模型参数,会保存到ps服务器中,每次训练迭代都会更新ps中的参数。
w = tf.Variable([[2], [2], [2], [2]])
mat = tf.matmul(x, w)
# 创建分布式会话
with tf.train.MonitoredTrainingSession(
master= "grpc://192.168.65.44:2222", # 指定主worker (所有worker共用主worker创建的会话)
is_chief= (FLAGS.task_index == 0), # 判断是否是主worker
config=tf.ConfigProto(log_device_placement=True), # 打印运算设备(CPU/GPU)信息
hooks=[tf.train.StopAtStepHook(last_step=200)] # 钩子列表。 StopAtStepHook可以指定op的训练次数(迭代次数)
) as mon_sess:
while not mon_sess.should_stop():
print(mon_sess.run(mat))