tensorflow保存和加载模型

1、tf保存模型

tf.summary.scalar('accuracy',acc)                  
merge_summary = tf.summary.merge_all()  
train_writer = tf.summary.FileWriter(dir,sess.graph)
......(交叉熵、优化器等定义)  
saver = tf.train.Saver()
for step in xrange(training_step):       
    if step%1000==0:           
        saver.save(sess,checkpoint_dir,global_step=step)
        train_summary = sess.run(merge_summary,feed_dict =  {...})
        train_writer.add_summary(train_summary,step)

2、tf保存之后的模型

tensorflow保存和加载模型

主要是三个文件,一个是.data文件(网络的权值,偏置,操作),一个是.index文件(“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等)和.meta文件(图结构) 。我们主要看一下checkpoint文件,打开如下:

tensorflow保存和加载模型

可以看到保存的都是路径名,看到第一行默认保存的是最新的模型路径。

3、tf模型的加载

def checkpoint_load(path):
    print('Reading Checkpoints... .. .\n')
    ckpt = tf.train.get_checkpoint_state(path)
    print(ckpt)

print如下:

model_checkpoint_path: "./model/mnist_model-49001"
all_model_checkpoint_paths: "./model/mnist_model-45001"
all_model_checkpoint_paths: "./model/mnist_model-46001"
all_model_checkpoint_paths: "./model/mnist_model-47001"
all_model_checkpoint_paths: "./model/mnist_model-48001"
all_model_checkpoint_paths: "./model/mnist_model-49001"

所以可以看到tf.train.get_checkpoint_state(path)返回两个结果分别是:

ckpt.model_checkpoint_path
ckpt.all_model_checkpoint_paths

一般使用断点续训的时候我们只需要判断ckpt.model_checkpoint_path加载最新的模型即可:

   if ckpt and ckpt.model_checkpoint_path:
      ckpt_path = str(ckpt.model_checkpoint_path)
      self.saver.restore(self.sess, os.path.join(os.getcwd(), ckpt_path))
      step = int(os.path.basename(ckpt_path).split('-')[1])
      print("\nCheckpoint Loading Success! %s\n" % ckpt_path)