Loading a TensorFlow-generated network model (.pb file) with OpenCV 3.4.1
The Python code below trains a handwritten-digit (MNIST) network and exports it as a frozen .pb file:
# -*- coding: utf-8 -*-
"""
Created on Fri May 18 16:04:06 2018
@author: Administrator
"""
import pylab
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util
tf.reset_default_graph()
learning_rate = 0.001
training_epochs = 25
batch_size = 100
display_step = 1
n_hidden_1 = 256
n_hidden_2 = 256
n_input = 784
n_classes = 10
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)
print ('input data:', mnist.train.images)
print ('input data shape:', mnist.train.images.shape)
x = tf.placeholder("float", [None, n_input], name = 'input')
y = tf.placeholder("float", [None, n_classes], name = 'labels')
def multilayer_perception(x, weights_t, biases_t):
    # first hidden layer
    layer_1 = tf.add(tf.matmul(x, weights_t['h1']), biases_t['b1'])
    layer_1 = tf.nn.relu(layer_1)
    # second hidden layer
    layer_2 = tf.add(tf.matmul(layer_1, weights_t['h2']), biases_t['b2'])
    layer_2 = tf.nn.relu(layer_2)
    # output layer, named "output" so it can be referenced when freezing the graph
    out_layer = tf.add(tf.matmul(layer_2, weights_t['out']), biases_t['out'], name = "output")
    return out_layer
weights = {
    'h1' : tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    'h2' : tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))
}
biases = {
    'b1' : tf.Variable(tf.random_normal([n_hidden_1])),
    'b2' : tf.Variable(tf.random_normal([n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}
#weights_c = {
# 'h1' : tf.constant(tf.random_normal([n_input, n_hidden_1])),
# 'h2' : tf.constant(tf.random_normal([n_hidden_1, n_hidden_2])),
# 'out': tf.constant(tf.random_normal([n_hidden_2, n_classes]))
# }
#
#biases_c = {
# 'b1' : tf.constant(tf.random_normal([n_hidden_1])),
# 'b2' : tf.constant(tf.random_normal([n_hidden_2])),
# 'out': tf.constant(tf.random_normal([n_classes]))
# }
print("learn_param")
pred = multilayer_perception(x, weights, biases)
print("multilayer_perception")
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
saver = tf.train.Saver()
savedir = "log1/"
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples/batch_size)
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            _, c = sess.run([optimizer, cost], feed_dict = {x:batch_xs, y:batch_ys})
            avg_cost += c/total_batch
        if (epoch + 1) % display_step == 0:
            print ("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
    print("finished")
    # Freeze the graph: convert_variables_to_constants turns the graph variables
    # into constants, so all network parameters can be exported at once instead
    # of converting each parameter to a constant by hand.
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
    with tf.gfile.FastGFile('expert-graph_t1.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())
    #im = mnist.train.images[1]
    #im = im.reshape(-1, 28)
    #pylab.imshow(im)
    #pylab.show()
    # test the model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
    save_path = saver.save(sess, savedir + "mnistmodel.cpk")
    print ("Model saved in file:%s" % save_path)
# weights_c = {}
# biases_c = {}
# _h1 = weights['h1'].eval(sess)
# _h2 = weights['h2'].eval(sess)
# _out_w = weights['out'].eval(sess)
#
# _b1 = biases['b1'].eval(sess)
# _b2 = biases['b2'].eval(sess)
# _out_b = biases['out'].eval(sess)
#
###pred_c = tf.Constant(pred, name = "pred_out")
#g_2 = tf.Graph()
#with g_2.as_default():
# weights_c['h1'] = tf.constant(_h1 , name = "w_h1")
# weights_c['h2'] = tf.constant(_h2, name = "w_h2")
# weights_c['out'] = tf.constant(_out_w, name = "w_out")
#
# biases_c['b1'] = tf.constant(_b1, name = "w_b1")
# biases_c['b2'] = tf.constant(_b2, name = "w_b2")
# biases_c['out'] = tf.constant(_out_b, name = "b_bout")
# X_2 = tf.placeholder("float", shape = [None, 784], name = "input")
# pred = multilayer_perception(X_2, weights_c, biases_c)
# print ("pred")
# pred_out = tf.nn.softmax(pred, name = "pred_out")
## sess_2 = tf.Session()
## init_2 = tf.initialize_all_variables()
## sess_2.run(init_2)
# #graph_def = g_2.as_graph_def()
# #tf.train.write_graph(graph_def, savedir, 'expert-graph.pb', as_text=False)
# graph_def = g_2.as_graph_def()
# #tf.train.write_graph(graph_def, export_dir, 'expert-graph_t1.pb', as_text=True)
# tf.train.write_graph(graph_def, savedir, 'expert-graph.pb', as_text=False)
## print ("pred1")
## constant_graph = g_2.as_graph_def()
## with tf.gfile.FastGFile(savedir+'expert-graph.pb', mode='wb') as f:
## f.write(constant_graph.SerializeToString())
#
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, savedir + "mnistmodel.cpk")
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
    output = tf.argmax(pred, 1)
    batch_xs, batch_ys = mnist.train.next_batch(2)
    outval, predv = sess.run([output, pred], feed_dict = {x:batch_xs, y:batch_ys})
    print(outval, predv, batch_ys)
    im = batch_xs[0]
    im = im.reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()
    im = batch_xs[1]
    im = im.reshape(-1, 28)
    print (im.shape)
    pylab.imshow(im)
    # save the second sample to disk; this is the image fed to OpenCV below
    pylab.imsave(savedir + "mnistimg.bmp", im)
    pylab.show()
The image used for the recognition test here is the mnistimg.bmp saved by the Python script above (the second sample of the batch).
The OpenCV code is as follows:
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>

cv::String weights = "expert-graph_t1.pb";
cv::dnn::Net net = cv::dnn::readNetFromTensorflow(weights);  // load the frozen .pb graph as a dnn::Net
cv::Mat img = cv::imread("mnistimg.bmp", 0);                 // read the test image as grayscale
// Convert to a 1x1x28x28 float blob: scale pixel values by 1/255, no mean subtraction, no channel swap, no crop.
cv::Mat inputBlob = cv::dnn::blobFromImage(img, 1.0/255.0, cv::Size(28, 28), cv::Scalar(), false, false);
net.setInput(inputBlob);       // feed the blob to the network
cv::Mat pred = net.forward();  // forward pass; pred holds the 10 class scores
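Before interpreting the numbers it can help to check which layer OpenCV treats as the output of the imported graph (the MatMul_2 shown in the dump below). A small sketch continuing from the snippet above, assuming <iostream> and <vector> are included; layer ids in the dnn API are 1-based, so the name lookup subtracts one:

// List the layers the dnn module treats as outputs of the imported graph.
std::vector<cv::String> layerNames = net.getLayerNames();
for (int id : net.getUnconnectedOutLayers())
    std::cout << "output layer: " << layerNames[id - 1].c_str() << std::endl;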
For the test image the forward pass produces the following raw scores:
Net Outputs(1):
MatMul_2
257.719 -255 239.082 1127.92 -400.36 765.277 -326.79 -310.638 245.007 134.649
No softmax has been applied to this output; the largest score is the one at index 3, which matches the test result printed by the Python code above.
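To turn the raw scores into a predicted digit it is enough to take the index of the maximum score; a softmax only rescales the values into probabilities and does not change which index is largest. A minimal sketch continuing from the snippet above (the names maxVal, maxLoc, expScores and predictedDigit are illustrative, not from the original code):

// The column of the largest value in the 1x10 score row is the predicted digit.
double maxVal;
cv::Point maxLoc;
cv::minMaxLoc(pred, nullptr, &maxVal, nullptr, &maxLoc);
int predictedDigit = maxLoc.x;   // 3 for the scores shown above
std::cout << "predicted digit: " << predictedDigit << std::endl;

// Optional: turn the scores into probabilities with a softmax.
// Subtracting the maximum first keeps cv::exp numerically stable.
cv::Mat shifted = pred - maxVal;
cv::Mat expScores;
cv::exp(shifted, expScores);
cv::Mat probs = expScores / cv::sum(expScores)[0];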