30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
文章目录
写在前头
本文与 30天 吃掉那只 tensorflow 的原文有较大的出入,只是借鉴他使用的数据集并且整个构建网络的过程;数据处理部分和训练网络部分都是自己设计的,测试部分我也使用了自己上网找到图片并进行了处理;仅供大家参考
1. Cifar10数据集的介绍、获取
cifar10 数据集使用了 10 种分类的训练数据;标签从 0-9
获取方式: 如下:
使用强大的 jupyterlab 工具来查看一下load_data
的用法;这个函数会自动帮你下载数据集(这种通过命令行直接从 tensorflow 官网中下载大约需要2个小时,可以通过别的方式获取,请自行查阅相关文章)
2. 训练集数据可视化
load_data
函数最终返回 2 个 tuple;所以我们用以下代码来接受返回值
可视化一下前三个训练数据
- 我们可以看到训练数据的前三个的标签为 6,9,9
- 也可以看到我们可视化出来的结果前三张图分别是:青蛙、卡车、卡车
![]()
3. 简单数据处理:将标签进行 one-hot 编码转换
对于多分类问题,我们一般会把标签转换成 one-hot 编码的形式,为了以后更容易计算;所以我们在这里用到了
to_categorical
函数来转换
还记得么y_train
的前三个标签是6,9,9,
现在变成了对应位置为 1 其他位置为0 的one-hot
编码,这是为了在后面的计算中对应的使用 softmax 函数计算出概率分布,并通过相应位置的概率分布计算损失;
4. 构建网络模型
可以用
model.summary()
来查看你已经建立的网络结构还有需要训练的参数
5. 模型训练
6. 训练数据可视化
先来看看 history 里有哪些可以可视化的数据:
发现有['loss', 'acc', 'val_loss', 'val_acc']
我们使用索引把他们的值分别拿出来并展示以下,他们每一组数据都存在一个列表里,我们就用这些数据可视化每一个epoch
的训练过程
把 loss 和 val_loss 呈现在一张图中
把 acc 和 val_acc 也放在同一个图中进行对比
两个图像表明,数据被训练的很好,也没有存在过拟合现象。
7. 数据评估
数据评估的准确率甚至略高于训练集,这也是很好的结果
下面,我从网上随便照一张图片,用我们训练好的模型来检测一下训练成果。下面是我找的网图;下面的代码中我将演示如何处理这张网图,然后用模型来进行预测。
7.1 数据处理
- 可以看到,刚读出来的图片是
3
个通道的彩图;我们上面训练的也使用的3
通道彩图;- 所以我们要对这个图片进行 resize;但是 resize 操作不能直接对 3 通道的图片做;所以:
- 我们按照 opencv 读图片的通道顺序
b, g, r
(注意不是 rgb) 使用cv2.split()
函数对数据解包;得到了每个通道之后我们分别做resize
操作,最后再用cv2.merge()
将三个通道叠加起来;这样我们就可以得到我们想要的结果了
7.2 将数据送到模型中测试
但是直接这样送入模型会报错;因为你用测试集进行测试的时候你的数据是 4 维的,
(10000, 32, 32, 3)
这个10000
代表的是10000
张测试图片;所以我们要进行测试,要把这个测试的数据升高一个维度,或者把n
张图片绑到一个数组里面送去测试,即把矩阵变成(n, 32, 32, 3)
这种模式;
为了简单起见,我直接用两张它自己作为测试集;
注意!!! 把这两张图片放在一个列表中之后,千万不要忘记将这个列表用 numpy 转换成一个矩阵;因为模型的输入只能是矩阵而不能是列表
按照概率分布中的结果,我们筛选出最大的值来看一看是不是我们想要的标签
使用.argmax()
函数返回数组中最大的值的索引;返回的位置是1
;我们回去看 cifar10 数据集中的 1 代表的是“汽车”
类
所以可以看出来,模型的训练效果还是不错的。模型的保存部分,大家可以翻看我的上篇文章。
写在后面
如有错误,敬请指正;欢迎交流
个人的微信号: