horovod + tf.train.CheckpointSaverHook
I have recently been working on distributed model training using the Horovod framework. When combining Horovod with tf.train.MonitoredTrainingSession and tf.train.CheckpointSaverHook, I hit errors that looked like Horovod ranks preempting one another, and the "Create CheckpointSaverHook" message appeared multiple times in the log.
Worse, because of MonitoredTrainingSession's restart mechanism, the session kept restarting and the same error repeated over and over.
The relevant code was as follows:
...
checkpoint_path = ... if hvd.rank() == 0 else None
hook_worker = [hvd.BroadcastGlobalVariablesHook(0)]
hook_master = [
    # Note: save_steps and saver must be passed as keyword arguments;
    # CheckpointSaverHook's second positional parameter is save_secs.
    tf.train.CheckpointSaverHook(checkpoint_dir,
                                 save_steps=save_steps,
                                 saver=saver,
                                 checkpoint_basename='model.ckpt')
]
hooks = hook_worker + hook_master if hvd.rank() == 0 else hook_worker
with tf.train.MonitoredTrainingSession(checkpoint_dir=None,
                                       config=config,
                                       hooks=hooks) as mon_sess:
    ...
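The problem with this version is subtle: `hook_master`, and with it the `tf.train.CheckpointSaverHook` constructor, runs on every rank; only the final `hooks` assignment is guarded by `hvd.rank() == 0`. The effect can be reproduced without TensorFlow or Horovod. Below is a minimal sketch in which a hypothetical `FakeSaverHook` class stands in for `tf.train.CheckpointSaverHook` and simply counts how many times it is constructed:

```python
class FakeSaverHook:
    """Stand-in for tf.train.CheckpointSaverHook; counts constructor calls."""
    constructed = 0

    def __init__(self):
        FakeSaverHook.constructed += 1


def build_hooks_original(rank):
    # Original pattern: hook_master is built unconditionally, so the
    # saver hook is constructed on EVERY rank.
    hook_worker = ["broadcast_hook"]
    hook_master = [FakeSaverHook()]
    return hook_worker + hook_master if rank == 0 else hook_worker


def build_hooks_fixed(rank):
    # Fixed pattern: the rank check guards the construction itself.
    hook_worker = ["broadcast_hook"]
    hook_master = [FakeSaverHook()] if rank == 0 else []
    return hook_worker + hook_master


for rank in range(4):  # simulate 4 Horovod ranks
    build_hooks_original(rank)
print(FakeSaverHook.constructed)  # 4 -- one saver hook per rank

FakeSaverHook.constructed = 0
for rank in range(4):
    build_hooks_fixed(rank)
print(FakeSaverHook.constructed)  # 1 -- only rank 0 constructs it
```

This matches the symptom: four ranks construct four saver hooks (and would each log "Create CheckpointSaverHook"), even though only rank 0 was supposed to write checkpoints.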
I then changed this part of the code to:
hook_worker = [
    hvd.BroadcastGlobalVariablesHook(0)
]
if hvd.rank() == 0:
    hook_master = [
        tf.train.CheckpointSaverHook(...),
        tf.train.SummarySaverHook(...)
    ]
else:
    hook_master = []
hooks = hook_worker + hook_master
The root cause of this error probably lies in the CheckpointSaverHook source code, which needs further investigation.
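As for why the error repeated endlessly: to my understanding, MonitoredTrainingSession's recovery logic closes and recreates the underlying session on certain errors and then re-runs each hook's session-setup step, so a hook that fails during setup will fail again on every restart. A pure-Python sketch of that retry loop (the `FlakySaverHook` and `run_monitored` names here are hypothetical stand-ins, not TensorFlow APIs):

```python
class FlakySaverHook:
    """Stand-in for a hook whose session-setup step always fails, the way
    a saver hook racing other ranks over the checkpoint dir might."""

    def __init__(self):
        self.setup_attempts = 0

    def after_create_session(self):
        self.setup_attempts += 1
        raise RuntimeError("rank preemption while touching checkpoint dir")


def run_monitored(hooks, max_restarts=3):
    """Mimic the recovery loop: on error, recreate the session and
    re-run every hook's after_create_session."""
    errors = []
    for _ in range(max_restarts + 1):
        try:
            for h in hooks:
                h.after_create_session()
            return errors  # setup succeeded, proceed with training
        except RuntimeError as e:
            errors.append(str(e))  # restart and try again
    return errors


hook = FlakySaverHook()
errors = run_monitored([hook])
print(len(errors))          # 4 -- the same error, once per (re)start
print(hook.setup_attempts)  # 4
```

This would explain the log pattern in my case: the failure was in the hooks' setup path, so every automatic restart reproduced it verbatim instead of recovering.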