在tensorflow上训练cifar10数据集

运行环境:tensorflow2.0

概述

对CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题,其任务是对一组大小为32x32的RGB图像进行分类,这些图像涵盖了10个类别:
飞机, 汽车, 鸟, 猫, 鹿, 狗, 青蛙, 马, 船以及卡车。

在tensorflow上训练cifar10数据集

数据下载

在tensorflow上训练cifar10数据集


模型架构

模型是一个多层架构,由卷积层和非线性层(nonlinearities)交替多次排列后构成。这些层最终通过全连通层对接到softmax分类器上。

代码组织

代码位于tensorflow/models/image/cifar10/.

文件 作用
cifar10_input.py 读取本地CIFAR-10的二进制文件格式的内容。
cifar10.py 建立CIFAR-10的模型。
cifar10_train.py 在CPU或GPU上训练CIFAR-10的模型。
cifar10_multi_gpu_train.py 在多GPU上训练CIFAR-10的模型。
cifar10_eval.py 评估CIFAR-10模型的预测性能。
代码下载:

在tensorflow上训练cifar10数据集

CIFAR-10 网络模型部分的代码位于cifar10.py.完整的训练图中包含约765个操作。

运行cifar10_train.py,原代码迭代1000000,需要3,4个小时,

将其修改为迭代50000

在tensorflow上训练cifar10数据集

运行结果:

在tensorflow上训练cifar10数据集

cifar10_train.py 会周期性的在检查点文件保存模型中的所有参数,参数信息保存在tmp/cifar_10文件夹下

在tensorflow上训练cifar10数据集

可以通过tensorboard将训练过程可视化,可以通过cifar10_train.py中的SummaryWriter周期性的获取并显示这些数据。

可以在另一部分数据集上来评估训练模型的性能。脚本文件cifar10_eval.py对模型进行了评估

运行结果

在tensorflow上训练cifar10数据集

结果显示准确率为85.5%

在tensorflow上训练cifar10数据集

相关文件保存在tmp/cifar10_eval下