TensorFlow:实战Google深度学习框架(三)模型保存和持久化
-
tensorflow提供API可以实现神经网络的保存和还原 tf.train.Saver()类
-
-
模型的加载
-
模型的加载还可以用另外一种方式实现
-
加载部分变量,把类tf.train.Saver()变成tf.train.Saver([v1]),就可以实现只加载v1变量
-
如果用类tf.train.Saver(),而不是指定要加载的变量时,就要保证需要加载的变量都在模型中,否则模型会报错
-
保存和加载时可以给变量重命名
-
变量重命名的目的之一是为了方便使用变量的滑动平均值(由于滑动平均之可以使神经网路模型更加鲁棒);tensorflow中的滑动平均值是通过影子变量进行维护的,所以要获取变量的滑动值实际上就是获取这个影子变量的值,那么在使用训练好的模型时不再需要调用函数获取变量的滑动平均值,只需要在加载时将影子变量的值映射到变量本身即可
-
示例:
-
为了方便加载时重命名滑动平均变量,tf.train.ExponentialMovingAverage类提供了variable_to_restore函数来生成tf,train.Saver()类所需要的变量重命名字典。
-
-
将训练好的计算图中的变量及其取值通过常量的方式保存
-
,这样可以在加载模型时,只要指定要加载的张量的名称就可以将对应的张量的值加载进来,不用再进行重新的训练
-
对应1中计算图的加载
-
做一个接口用以输入数据
-
-
持久化的原理及数据格式
-
json文件中保存的模型信息分析
-
meta_info_def属性
-
graph_def属性
-
saver_def属性
-
collection_def属性
-