tensorflow保存和加载模型
1、tf保存模型
tf.summary.scalar('accuracy',acc)
merge_summary = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(dir,sess.graph)
......(交叉熵、优化器等定义)
saver = tf.train.Saver()
for step in xrange(training_step):
if step%1000==0:
saver.save(sess,checkpoint_dir,global_step=step)
train_summary = sess.run(merge_summary,feed_dict = {...})
train_writer.add_summary(train_summary,step)
2、tf保存之后的模型
主要是三个文件,一个是.data文件(网络的权值,偏置,操作),一个是.index文件(“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等)和.meta文件(图结构) 。我们主要看一下checkpoint文件,打开如下:
可以看到保存的都是路径名,看到第一行默认保存的是最新的模型路径。
3、tf模型的加载
def checkpoint_load(path):
print('Reading Checkpoints... .. .\n')
ckpt = tf.train.get_checkpoint_state(path)
print(ckpt)
print如下:
model_checkpoint_path: "./model/mnist_model-49001"
all_model_checkpoint_paths: "./model/mnist_model-45001"
all_model_checkpoint_paths: "./model/mnist_model-46001"
all_model_checkpoint_paths: "./model/mnist_model-47001"
all_model_checkpoint_paths: "./model/mnist_model-48001"
all_model_checkpoint_paths: "./model/mnist_model-49001"
所以可以看到tf.train.get_checkpoint_state(path)返回两个结果分别是:
ckpt.model_checkpoint_path
ckpt.all_model_checkpoint_paths
一般使用断点续训的时候我们只需要判断ckpt.model_checkpoint_path加载最新的模型即可:
if ckpt and ckpt.model_checkpoint_path:
ckpt_path = str(ckpt.model_checkpoint_path)
self.saver.restore(self.sess, os.path.join(os.getcwd(), ckpt_path))
step = int(os.path.basename(ckpt_path).split('-')[1])
print("\nCheckpoint Loading Success! %s\n" % ckpt_path)