基于TENSORFLOW的手写数字识别

1.引言

随着人工智能的发展,人工智能已经广泛应用到各个领域,以Tensorflow框架为深度学习工具的应用已经相当广泛,卷积神经网络是一类包含卷积运算且具有深度结构的前馈神经网络,采用反向传播(Back Propagation,BP)算法对模型进行学习训练,手写字体识别模型LeNet5诞生于1994年,是最早的卷积神经网络之一。LeNet5通过巧妙的设计,利用卷积、参数共享、池化等操作提取特征,避免了大量的计算成本,最后再使用全连接神经网络进行分类识别,这个网络也是最近大量神经网络架构的起点。本文基于Tensorflow,结合深度学习框架,利用Softmax回归算法进行多分类,结合CNN卷积神经网络实现对手写数字体的识别。

LeNet5结构图:

基于TENSORFLOW的手写数字识别
MNIST数据集:
基于TENSORFLOW的手写数字识别

2.系统结构

TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理。Tensor(张量)意味着N维数组,Flow(流)意味着基于数据流图的计算,TensorFlow为张量从流图的一端流动到另一端计算过程。TensorFlow是将复杂的数据结构传输至人工智能神经网中进行分析和处理过程的系统。TensorFlow使用计算图表来执行其所有的计算。计算被表示为tf.Graph对象的一个实例,而其中的数据被表示为tf.Tensor对象,并使用tf.Operation对象对这样的张量对象进行操作。然后再使用tf.Session对象的会话中执行该图表。图中节点(Nodes)一般表示施加的数学操作,或者表示数据输入(feed in)的起点/输出(push out)的终点,或者是读取/写入持久变量(persistent variable)的终点,图中线/边(edges)则表示在节点间相互联系的多维数据组,即张量(tensor)。

卷积神经网络结构一般含有以下几层:输入层,卷积层,池化层,全连接层,DropOut层,输出层

输入层就是整个神经网络的输入,在本文的卷积神经网络中,输入就是一张图片,三维卷积神经网络的输入层接受一个四维数组: [样本数量,样本长,样本宽,样本深度(对应图片通道数)]。

卷积层的功能是对输入数据进行特征提取,抽象特征,每一个卷积层都是通过一个可调的卷积核与上一层特诊图进行卷积运算,再加上一个偏移量得到输出,再经过**函数得到结果。通过卷积我们可以逐步得到更高层次的特征。特征是不断进行提取和压缩的,最终能得到比较高层次特征,简言之就是对原始特征一步又一步的浓缩,最终得到的特征更可靠。利用最后一层特征可以做各种任务:比如分类、回归等。

池化层的作用主要是在保留主要特征的同时减少参数和计算量,达到降维的目的,去除冗余信息,简化网络复杂度。

全连接层在整个卷积神经网络中起到分类的作用,因为我们最终需要的结果是一个分类的结果,是一维的,通过全连接层我们将输入的多维数据转为一维数据输出,在 CNN 中,全连接常出现在最后几层,用于对前面设计的特征做加权和。比如 mnist,前面的卷积和池化相当于做特征工程,后面的全连接相当于做特征加权。

为了防止训练模型的过拟合,需要在卷积神经网络中添加一层DropOut层,该层随机丢弃部分参数,此机制将保证神经网络不会对训练样本过于匹配,这将帮助缓解过拟合问题。

SoftMax函数, softmax 用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,公式如下:基于TENSORFLOW的手写数字识别

基于TENSORFLOW的手写数字识别

SoftMax的交叉熵损失函数如下:

基于TENSORFLOW的手写数字识别

ReLu函数,使用ReLu**函数的原因是ReLu能更加有效率的梯度下降以及反向传播:避免了梯度爆炸和梯度消失问题,同时简化计算过程:没有了其他复杂**函数中诸如指数函数的影响;同时活跃度的分散性使得神经网络整体计算成本下降

基于TENSORFLOW的手写数字识别

ReLu公式如下:

基于TENSORFLOW的手写数字识别

3.实现代码

模型训练相关代码:

MNIST数据集导入代码:

基于TENSORFLOW的手写数字识别

训练参数的定义:

基于TENSORFLOW的手写数字识别

输入参数的定义:
基于TENSORFLOW的手写数字识别

卷积层定义:

定义一个输入为x,权值为W,偏置为b,给定步幅的卷积层,**函数是ReLU,padding设定为same模式

基于TENSORFLOW的手写数字识别

池化层定义:定义一个输入时x的maxpool层,卷积核为ksize 并且padding为SAME

基于TENSORFLOW的手写数字识别

权值定义:

基于TENSORFLOW的手写数字识别
偏执定义:

基于TENSORFLOW的手写数字识别

卷积神经网络定义:

基于TENSORFLOW的手写数字识别

损失函数及其它一些辅助值定义:

基于TENSORFLOW的手写数字识别

模型训练代码如下:

基于TENSORFLOW的手写数字识别

将训练好的模型保存代码:

基于TENSORFLOW的手写数字识别
最终可视化训练结果的代码如下:

基于TENSORFLOW的手写数字识别

基于PYQT的手写板类PaintBoard设计代码:

  1. 类初始化代码:新建PaintBoard类继承QWidget,之后执行两个函数,一个是数据的初始化函数,一个是视图的初始化函数

基于TENSORFLOW的手写数字识别

  1. 数据初始化函数:

基于TENSORFLOW的手写数字识别

该函数主要是新建QPixmap作为画板,新建QPainter作为绘图工具,定义了画笔颜色和橡皮擦等,同时定义了鼠标位置。

  1. 视图初始化函数

基于TENSORFLOW的手写数字识别

该函数主要是设置好画板的大小,我们将整个窗口作为画板。

  1. 定义清空画板函数

基于TENSORFLOW的手写数字识别

该函数将画板填充成白色之后调用update()方法触发重绘事件,达到画板更新

  1. 定义改变画笔颜色接口

基于TENSORFLOW的手写数字识别

  1. 定义改变画笔粗细接口

基于TENSORFLOW的手写数字识别

  1. 定义判断画板是否为空接口

基于TENSORFLOW的手写数字识别

  1. 定义获取画板内容接口

基于TENSORFLOW的手写数字识别

该函数获取画板内容后以QImage的形式返回。

  1. 重写paintEvent事件

基于TENSORFLOW的手写数字识别

该函数调用QPainter将QPixmap的图像绘画到程序窗口上

  1. 重写mousePressEvent事件

基于TENSORFLOW的手写数字识别

当鼠标按下时我们保存鼠标的位置

  1. 重写mouseMoveEvent事件

基于TENSORFLOW的手写数字识别

绘画的原理就是在上一个位置和下一个鼠标位置之间画线,之后根据有无橡皮擦调整画笔颜色,最终调用update()更新显示

  1. 重写mouseReleaseEvent事件

基于TENSORFLOW的手写数字识别

当鼠标松开时此时画板必定有内容 所以设置画板不为空,至此手写板的设计完成,接着就是设计主界面后嵌入手写板。

基于PYQT的主界面设计:

  1. 类初始化代码:新建MainWidget类继承QWidget,之后执行两个函数,一个是数据的初始化函数,一个是视图的初始化函数

基于TENSORFLOW的手写数字识别

  1. 数据初始化代码

基于TENSORFLOW的手写数字识别

主界面需要含有绘画板,所以这里要新建一个前面写的绘画板

  1. 视图初始化代码

先设置主界面大小和标题,之后新建水平布局作为整个窗口的主布局。

接着设置一个垂直子布局,包含绘画板和一个QLabel作为识别结果的提示。

接着再设置一个垂直子布局,用来包含按钮和一些功能部件。

基于TENSORFLOW的手写数字识别

基于TENSORFLOW的手写数字识别

基于TENSORFLOW的手写数字识别

  1. 定义填充颜色选择下拉列表的代码

基于TENSORFLOW的手写数字识别

该代码遍历颜色列表后将颜色填充进下拉列表,并选择黑色为默认颜色。

  1. 定义改变画笔颜色的回调代码

基于TENSORFLOW的手写数字识别

该代码主要是获取下拉框当前的颜色索引之后取出颜色字符串,调用绘画板的接口实现画笔颜色改变。

  1. 定义改变画笔粗细的回调代码

基于TENSORFLOW的手写数字识别

该代码获取spinbox的值之后调用绘画板接口改变画笔粗细。

  1. 定义退出按钮的响应函数:

基于TENSORFLOW的手写数字识别

  1. 定义识别按钮的响应函数:

代码详解见注释,这里要注意因为MNIST的数据集是黑底白字,所以我们的数据也要一致,且神经网络的输入时浮点型数据,这里涉及到一个背景反转和类型转换,之前就是没注意到黑底白字问题导致识别不好。

基于TENSORFLOW的手写数字识别

加载模型相关代码:
基于TENSORFLOW的手写数字识别

实验结果

模型训练结果:

基于TENSORFLOW的手写数字识别

训练集和测试集准确率:

基于TENSORFLOW的手写数字识别

SoftMax损失函数收敛过程:

基于TENSORFLOW的手写数字识别

实际测试识别结果:

基于TENSORFLOW的手写数字识别

数字0识别测试:

基于TENSORFLOW的手写数字识别

数字1识别测试:

基于TENSORFLOW的手写数字识别

数字2识别测试:

基于TENSORFLOW的手写数字识别

数字3识别测试:

基于TENSORFLOW的手写数字识别
数字4识别测试

基于TENSORFLOW的手写数字识别

数字5识别测试

基于TENSORFLOW的手写数字识别

数字6识别测试

基于TENSORFLOW的手写数字识别

数字7识别测试

基于TENSORFLOW的手写数字识别

数字8识别测试

基于TENSORFLOW的手写数字识别

数字9识别测试

基于TENSORFLOW的手写数字识别