MNIST数据集的学习笔记一
源码出处:https://www.cnblogs.com/yinzm/p/7110870.html,由于源码注释过多,没有一定的理解,自己做了一下小白笔记([email protected]邮箱,如博文有错,恳请联系,谢谢!)。
# -*- coding: utf-8 -*- # 由于书上使用的TensorFlow版本比较旧,所以有些代码有所改动, # 本人使用的TensorFlow版本为1.2.0 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data
MNIST数据集可在官网获取,解压后如下:
(1).6w张28*28像素点的0-9的手写数字图片 训练用。可将其图片像素信息保存为一个784(28*28)长度的一位数组,用于储存信息。打开一张MNIST数据集的信息,会发现黑色背景为0,储存在一维数组中,占大多数。字体白色为1,0到1为较两者中间色度区域。
(2).6w张图片对应的label,这些以一位数组储存。如【0,0,0,0,1,0,0,0,0,0】表示从0~9的概率分别为0,0,0,0,1······,当然就是指这个标签在为4的概率为100%。对应图片手写数字为4.
(3)(4),同样格式,1W张手写数字image 和 label 用于test。
用代码下载的数据集保存在项目文件夹下的./data文件夹中,在主程序中,也必有一行代码用于获取MNIST数据集,如下。input_data模块已经从上面from了,其方法read_data_sets()第一个参数规定了MNIST数据集储存的位置,类似在D:\pycharmProjects\data中,第二个参数告知以one_hot(独热编码)方式读取。当程序运行时检测到该路径没有MNIST数据集,会自动到官网下载,并划分为TRAIN 、VALIDATION和TEST两个子集。
mnist = input_data.read_data_sets("./data", one_hot=True)
#主程序 def main(argv=None): # 声明处理MNIST数据集的类,这个类在初始化时会自动下载数据。 mnist = input_data.read_data_sets("./data", one_hot=True) train(mnist) print("now trian has loaded") print("train data size :",mnist.train.num_examples) print("validation data size :", mnist.validation.num_examples) print("test data size :", mnist.test.num_examples) # TensorFlow提供的一个主程序入口,tf.app.run会调用上面定义的main函数 if __name__ == "__main__": tf.app.run()
运行结果:
此时打印出训练子集、验证子集、测试子集样本数分别为55000、5000、10000。
打开数据内部,可以看到labels[0]的3概率为100%,故对应图片的数字应该为3。
打开images[0],可以看到第一张图片的二维数组是这样的。
当然,如果电脑显卡不行,可以减少程序的运行量,如下。使用BATCH定义了一个数字,使用next_batch()方法取出BATCH大小的数据的图片和标签,很明显xs为图片数据,ys为对应的标签数据。
BATCH_SIZE = 200 xs, ys = mnist.train.next_batch(BATCH_SIZE) print("xs shape : ", xs.shape) <<< xs shape : 200 784 print("ys shape : ", ys.shape) <<< ys shape : 200 10