MNIST数据集的学习笔记一

 源码出处:https://www.cnblogs.com/yinzm/p/7110870.html,由于源码注释过多,没有一定的理解,自己做了一下小白笔记([email protected]邮箱,如博文有错,恳请联系,谢谢!)。

# -*- coding: utf-8 -*-
# 由于书上使用的TensorFlow版本比较旧,所以有些代码有所改动,
# 本人使用的TensorFlow版本为1.2.0

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

 MNIST数据集可在官网获取,解压后如下:

MNIST数据集的学习笔记一

(1).6w张28*28像素点的0-9的手写数字图片  训练用。可将其图片像素信息保存为一个784(28*28)长度的一位数组,用于储存信息。打开一张MNIST数据集的信息,会发现黑色背景为0,储存在一维数组中,占大多数。字体白色为1,0到1为较两者中间色度区域。

(2).6w张图片对应的label,这些以一位数组储存。如【0,0,0,0,1,0,0,0,0,0】表示从0~9的概率分别为0,0,0,0,1······,当然就是指这个标签在为4的概率为100%。对应图片手写数字为4.

(3)(4),同样格式,1W张手写数字image 和 label 用于test。

用代码下载的数据集保存在项目文件夹下的./data文件夹中,在主程序中,也必有一行代码用于获取MNIST数据集,如下。input_data模块已经从上面from了,其方法read_data_sets()第一个参数规定了MNIST数据集储存的位置,类似在D:\pycharmProjects\data中,第二个参数告知以one_hot(独热编码)方式读取。当程序运行时检测到该路径没有MNIST数据集,会自动到官网下载,并划分为TRAIN 、VALIDATION和TEST两个子集。

mnist = input_data.read_data_sets("./data", one_hot=True)
#主程序
def main(argv=None):
    # 声明处理MNIST数据集的类,这个类在初始化时会自动下载数据。
    mnist = input_data.read_data_sets("./data", one_hot=True)
    train(mnist)
    print("now trian has loaded")
    print("train data size :",mnist.train.num_examples)
    print("validation data size :", mnist.validation.num_examples)
    print("test data size :", mnist.test.num_examples)
# TensorFlow提供的一个主程序入口,tf.app.run会调用上面定义的main函数
if __name__ == "__main__":
    tf.app.run()

 运行结果:

MNIST数据集的学习笔记一

此时打印出训练子集、验证子集、测试子集样本数分别为55000、5000、10000。

打开数据内部,可以看到labels[0]的3概率为100%,故对应图片的数字应该为3。

 MNIST数据集的学习笔记一

 打开images[0],可以看到第一张图片的二维数组是这样的。MNIST数据集的学习笔记一

 当然,如果电脑显卡不行,可以减少程序的运行量,如下。使用BATCH定义了一个数字,使用next_batch()方法取出BATCH大小的数据的图片和标签,很明显xs为图片数据,ys为对应的标签数据。

BATCH_SIZE = 200
xs, ys = mnist.train.next_batch(BATCH_SIZE)
print("xs shape : ", xs.shape)   <<< xs shape : 200 784 
print("ys shape : ", ys.shape)   <<< ys shape : 200 10 

MNIST数据集的学习笔记一