Tensorflow分布式MirroredStrategy简介
最近由于一直在使用tensorflow多卡训练,遇到一些问题,于是查看了一些关于estimator关于多卡分布式策略的代码,主要了解了关于MirroredStrategy的相关内容。tf.estimator.Estimator初始化时可以在config中train_distribute设置相应的分布式策略,今天主要记录train_distributtf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)镜像策略。Estimator中分布式训练由_train_model_distributed(self, input_fn, hooks, saving_listeners)函数执行。
MirroredStrategy主要位于mirrored_strategy文件中,mirrorstrategy对于变量的处理详见create_variable函数中。对于第一个设备,采用原来的名字。对于大于0的设备,在原来变量名后加上/replica_加上设备号,以区别原始变量,并将原来的变量值复制给这些对应的复制变量。
对于输入数据集,主要通过调用distribute_dataset实现,每个设备各自取一份数据,所以是数据并行。
调用model_fn时,主要通过mirroredstrategy中的_call_for_each_tower实现。每个设备各起一个线程,并行执行model_fn,直至所有model_fn都完成。
最终将每个model_fn得到的loss合并并求平均值,然后分发到每个训练操作中,进行对应卡的变量的梯度更新。