持久化的MNIST手写字识别,实现的那个MLP网络模型可以在MNIST数据集上得到98%的正确率

在前面的内容中我们写过一个完整的Tensorflow程序去解决MNSIT手写字识别问题,实现的那个MLP网络模型可以在MNIST数据集上得到98%的正确率,虽然准确率是很高但是也出现了一个问题,那就是每一次使用网络就需要重新训练一次,这样就造成了时间的浪费。尤其是在大型的神经网络中,模型训练的时间会变得更长,甚至会花上几天几周的时间。每次重新训练这样的大型神经网络显然是不行的,因此在训练的时候保存模型是非常有必要的。结合Tensorlfow模型持久化和前面所介绍的变量空间的知识,在本小节我们会重构MLP来实现一个更完善的MNIST手写字识别样例。重构之后并没有太大变化,只是将训练和测试分为两个独立的程序—训练神经网络的程序和预估网络模型正确率的程序。训练神经网络的程序可以输出训练好的模型,而评估程序可以检验最新模型的正确率,如果模型效果表现地和训练时一样出色,那么这个模型就可以算作成功了。重构之后的代码分为两个程序:一个是mnist_train.py,它定义了前向传播的过程以及神经网络中的参数并完成了神经网络的训练过程;另一个是mnist_evaluate.py,它定义了验证和测试的过程,我们先给出mnist_train.py的代码内容:

持久化的MNIST手写字识别,实现的那个MLP网络模型可以在MNIST数据集上得到98%的正确率
这是代码的上半部分,上部分代码是定义一个隐藏层,和训练的bacth、学习率等。下半部分代码是求loss函数,以及输出结果。

持久化的MNIST手写字识别,实现的那个MLP网络模型可以在MNIST数据集上得到98%的正确率
在新的训练代码中,没有输出在验证集和测试集上的正确率,而是在训练过程中每隔1000轮输出一次在当前训练bath上的损失函数的大小并且保存一次训练好的模型。下面的输出展示了在训练的前5000轮输出的损失以及在训练的后5000轮输出的损失:
持久化的MNIST手写字识别,实现的那个MLP网络模型可以在MNIST数据集上得到98%的正确率

根据代码中的目录打开存储这些模型的文件夹,会发现理论上该存储的30个模型上实际只存储了5个,也就是后5000轮的结果。在训练之初,确实会将模型从第一轮训练结果开始储存,但是随着模型的逐步增多,程序会按顺序丢弃那些稍早存储的模型。而mnist_evaluate.py程序是一个单独是评估程序,评估了保存的模型在验证数据集和测试数据集上的正确率并将正确率输出,这里没有设计代码的展示,我们只要知道这段代码没有涉及迭代训练的过程,所以运行起来很快。最后保存的模型在验证集上可以得到98%的正确率,在测试集上能够获得98.5%的正确率。总结:本节内容讲述了如何使用Tensorflow进行手写数字体识别的模型持久化,模型持久化的作用是非常明显的,在今后的学习中运用的非常多,读者可以在这方面多加联系。

关注小鲸融创,一起深度学习金融科技!

持久化的MNIST手写字识别,实现的那个MLP网络模型可以在MNIST数据集上得到98%的正确率