利用softmax函数对mnist数据集简单分类

mnist数据集的特点

  1. 每一张图片包含28**28个像素,我们把这一个数组展开成一个向量,长度是28*28=784。因此在 MNIST训练数据集中mnist.train.images 是一个形状为 [60000, 784] 的张量,第一个维度数字用 来索引图片,第二个维度数字用来索引每张图片中的像素点。图片里的某个像素的强度值介于0-1 之间。
  2. MNIST数据集的标签是介于0-9的数字,我们要把标签转化为“one-hot vectors”。一个onehot向量除了某一位数字是1以外,其余维度数字都是0,比如标签0将表示为([1,0,0,0,0,0,0,0,0,0]) ,标签3将表示为([0,0,0,1,0,0,0,0,0,0]) 。
  3. 因此, mnist.train.labels 是一个 [60000, 10] 的数字矩阵。
    例如,下面这幅图,代表的数字为5042利用softmax函数对mnist数据集简单分类

softmax函数:

  1. 我们知道MNIST的结果是0-9,我们的模型可能推测出一张图片是数字9的概率是80%,是数字8 的概率是10%,然后其他数字的概率更小,总体概率加起来等于1。这是一个使用softmax回归模型的经典案例。softmax模型可以用来给不同的对象分配概率。
    利用softmax函数对mnist数据集简单分类

程序如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import  input_data

# 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# 定义批次batch_size,一次性放入100张图片
batch_size = 100
# 计算一个有多少个批次
n_batch = mnist.train.num_examples // batch_size

# 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# 创建一个简单的神经网络
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros(10))
prediction = tf.nn.softmax(tf.matmul(x, W) + b)

# 二次代价函数
loss = tf.reduce_mean(tf.square(y - prediction))

# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# 初始化变量
init = tf.initialize_all_variables()

# 预测的结果
# tf.argmax()返回最大值所在的列
# 结果存放在一个bool型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))

# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y: mnist.test.labels})
        print("Iter" + str(epoch) + ",Testing Accuracy" + str(acc))

运行结果如下:

Iter0,Testing Accuracy0.7488
Iter1,Testing Accuracy0.8331
Iter2,Testing Accuracy0.8592
Iter3,Testing Accuracy0.8707
Iter4,Testing Accuracy0.8779
Iter5,Testing Accuracy0.8814
Iter6,Testing Accuracy0.885
Iter7,Testing Accuracy0.8884
Iter8,Testing Accuracy0.8917
Iter9,Testing Accuracy0.8936
Iter10,Testing Accuracy0.8962
Iter11,Testing Accuracy0.8968
Iter12,Testing Accuracy0.8982
Iter13,Testing Accuracy0.8994
Iter14,Testing Accuracy0.9009
Iter15,Testing Accuracy0.9023
Iter16,Testing Accuracy0.9031
Iter17,Testing Accuracy0.9037
Iter18,Testing Accuracy0.9044
Iter19,Testing Accuracy0.9053
Iter20,Testing Accuracy0.9053

准确率大概在90%左右。