从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)
1 解析CKPT网络结构
解析CKPT
网络结构的第一步是读取CKPT
模型中的图文件,得到图的Graph
对象后即可得到完整的网络结构。读取图文件示例代码如下所示。
saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True)
graph = tf.get_default_graph()
with tf.Session( graph=graph) as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess,ckpt_path)
- 1
- 2
- 3
- 4
- 5
调用graph.get_operations()
后即可得到当前图的所有计算节点,在利用Operation
对象与Tensor
对象之间的相互引用关系即可推断网络结构。但是需要注意的是,从meta
文件中导入的图中获取计算节点存在如下问题。
- 包含反向梯度下降计算的所有节点
- 某些计算节点是按基础计算(加减乘除等)节点拆分成多个计算节点的,如
BatchNorm
,但其实是可以直接合并成一个节点的。
pb
模型文件可以避免上面第一个问题,将CKPT
模型转pb
模型后,可以自动将反向梯度下降相关计算节点移除。对于第二点,pb
模型文件会自动将基础计算组成一个计算节点,但是对于Tensor操作的函数如Slice等函数是无法合并的。因此,对于第2个问题,将CKPT
模型转pb
模型后,可以减少这类问题,但是无法避免。彻底避免的方法只能通过自己针对性地实现。经过以上分析,得出的结论是非常有必要将CKPT
模型转pb
模型。
2 自动将CKPT转pb,并提取网络图中节点
如果将CKPT自动转pb模型,那么就可以复用上一篇文章《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》的代码。示例代码如下所示。
def read_graph_from_ckpt(ckpt_path,input_names,output_name ): saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True) graph = tf.get_default_graph() with tf.Session( graph=graph) as sess: sess.run(tf.global_variables_initializer()) saver.restore(sess,ckpt_path) output_tf =graph.get_tensor_by_name(output_name) pb_graph = tf.graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [output_tf.op.name])
<span class="token keyword">with</span> tf<span class="token punctuation">.</span>Graph<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>as_default<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">as</span> g<span class="token punctuation">:</span> tf<span class="token punctuation">.</span>import_graph_def<span class="token punctuation">(</span>pb_graph<span class="token punctuation">,</span> name<span class="token operator">=</span><span class="token string">''</span><span class="token punctuation">)</span> <span class="token keyword">with</span> tf<span class="token punctuation">.</span>Session<span class="token punctuation">(</span>graph<span class="token operator">=</span>g<span class="token punctuation">)</span> <span class="token keyword">as</span> sess<span class="token punctuation">:</span> OPS<span class="token operator">=</span>get_ops_from_pb<span class="token punctuation">(</span>g<span class="token punctuation">,</span>input_names<span class="token punctuation">,</span>output_name<span class="token punctuation">)</span> <span class="token keyword">return</span> OPS
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
其中函数get_ops_from_pb
在上一篇文章《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》中已经实现。
3 测试
以《MobileNet V1官方预训练模型的使用》文中介绍的MobileNet V1网络结构为例,下载MobileNet_v1_1.0_192文件并压缩后,得到mobilenet_v1_1.0_192.ckpt.data-00000-of-00001
、mobilenet_v1_1.0_192.ckpt.index
、mobilenet_v1_1.0_192.ckpt.meta
文件。我们还需要知道mobilenet_v1_1.0_192.ckpt
模型对应的输入和输出Tensor
对象的名称,官方提供的压缩包文件中并没有告知。一种方法是运行官方代码,把输入Tensor的名称打印出来。但是运行官方代码本身就需要一定的时间和精力,在在上一篇文章《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》的代码实现中已经实现了将原始网络结构对应的字符串写入到ori_network.txt
文件中。因此,可以先随意填写输入名称和输出名称,待生成ori_network.txt
文件后,从文件中可以直观看到原始网络结构。ori_network.txt
文件部分内容如下所示。
通过该文件可知,输入Tensor
的名称为:batch:0
,输出Tensor
名称为:MobilenetV1/Predictions/Reshape_1:0
。有了这些信息后,调用函数read_graph_from_ckpt
得到静态图的节点列表对象ops
,调用函数gen_graph(ops,"save/path/graph.html")
后,在目录save/path
中得到graph.html
文件,打开graph.html
后,显示结果如下。
4 源码地址
https://github.com/huachao1001/CNNGraph
</div>
<link href="https://****img.cn/release/phoenix/mdeditor/markdown_views-7b4cdcb592.css" rel="stylesheet">
</div>