tensorflow学习:模型的保存与恢复(saver)
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。
模型保存,先要创建一个Saver对象:如
saver=tf.train.Saver()
在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:
saver=tf.train.Saver(max_to_keep=0)
当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即:
saver=tf.train.Saver(max_to_keep=1)
还有一个参数keep_checkpoint_every_n_hours,每几小时保存一次模型,如:
saver = tf.train.Saver( keep_checkpoint_every_n_hours=2)
创建完saver对象后,就可以保存训练好的模型了,如:
saver.save(sess,'save_model/model.ckpt',global_step=step)
第一个参数sess,就是tensorflow中的session
第二个参数设定保存的路径和名字,执行代码之后,会在改文件目录下新建一个save_model文件夹用于保存模型,save_model里面保存训练的模型,模型名称为mnist.ckpt,具体有四类文件,checkpoint、ckpt.data、ckpt.index、ckpt.meta,如下图:这四类文件下面再讲。
第三个参数将训练的次数作为后缀加入到模型名字中。例如:
saver.save(sess,'save_model/model.ckpt',global_step=1000)
# filename: save_model/model.ckpt-1000
四类文件
-
meta文件
model.ckpt-200.meta文件保存的是图结构,通俗地讲就是神经网络的网络结构。一般而言网络结构是不会发生改变,所以可以只保存一个就行了。我们可以使用下面的代码只在第一次保存meta文件。
saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)
并且还可以使用tf.train.import_meta_graph(‘model.ckpt-200.meta’)能够导入图结构。
-
data文件
数据文件,保存的是网络的权值,偏置,操作等等 -
index文件
model.ckpt-200.index是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。 -
checkpoint
checkpoint文件内容如下图:
注意:Tensorflow默认只保存最近5个模型和元数据,删除前面没用的模型和元数据。其中,model_checkpoint_path里面保存着当前模型,all_model_checkpoint_paths里面保存着最近的5个模型。
注意:在tensorflow中使用断点时,要修改的地方:checkpoint文件,将它修改为上一次训练好的数值;学习率。
模型恢复
模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:
saver = tf.train.Saver()
model_file=tf.train.latest_checkpoint('checkpoint_dir/') #checkpoint_dir为保存模型的文件夹
saver.restore(sess,model_file)
也可以使用tf.train.get_checkpoint_state来加载模型。这种方法会读取checkpoint文件里面的model_checkpoint_path,然后加载相应的模型:
saver = tf.train.Saver()
MODEL_SAVE_PATH='save_model' # 模型路径
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
参考:
1.https://cloud.tencent.com/developer/news/393046
2.https://blog.****.net/changeforeve/article/details/80268522
3.https://www.cnblogs.com/hellcat/p/6925757.html
4.https://www.cnblogs.com/qwangxiao/p/9036493.html
5.https://www.cnblogs.com/denny402/p/6940134.html