在tensorflow上训练cifar10数据集
运行环境:tensorflow2.0
概述
对CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题,其任务是对一组大小为32x32的RGB图像进行分类,这些图像涵盖了10个类别:飞机, 汽车, 鸟, 猫, 鹿, 狗, 青蛙, 马, 船以及卡车。
数据下载
模型架构
模型是一个多层架构,由卷积层和非线性层(nonlinearities)交替多次排列后构成。这些层最终通过全连通层对接到softmax分类器上。
代码组织
代码位于tensorflow/models/image/cifar10/
.
文件 | 作用 |
---|---|
cifar10_input.py |
读取本地CIFAR-10的二进制文件格式的内容。 |
cifar10.py |
建立CIFAR-10的模型。 |
cifar10_train.py |
在CPU或GPU上训练CIFAR-10的模型。 |
cifar10_multi_gpu_train.py |
在多GPU上训练CIFAR-10的模型。 |
cifar10_eval.py |
评估CIFAR-10模型的预测性能。 |
CIFAR-10 网络模型部分的代码位于cifar10.py
.完整的训练图中包含约765个操作。
运行cifar10_train.py,原代码迭代1000000,需要3,4个小时,
将其修改为迭代50000
运行结果:
cifar10_train.py
会周期性的在检查点文件中保存模型中的所有参数,参数信息保存在tmp/cifar_10文件夹下
可以通过tensorboard将训练过程可视化,可以通过cifar10_train.py
中的SummaryWriter
周期性的获取并显示这些数据。
可以在另一部分数据集上来评估训练模型的性能。脚本文件cifar10_eval.py
对模型进行了评估
运行结果
结果显示准确率为85.5%
相关文件保存在tmp/cifar10_eval下