TensorFlow 入门之手写识别(MNIST) softmax算法
softmax回归算法
我们知道MNIST的每一张图片都表示一个数字,从0到9。我们希望得到给定图片代表每个数字的概率。比如说,我们的模型可能推测一张包含9的图片代表数字9的概率是80%但是判断它是8的概率是5%(因为8和9都有上半部分的小圆),然后给予它代表其他数字的概率更小的值。
这是一个使用softmax回归(softmax regression)模型的经典案例。 softmax 模型可以用来给不同的对象分配概率。即使在之后,我们训练更加精细的模型时,最后一步也需要用softmax来分配概率。
这是一个使用softmax回归(softmax regression)模型的经典案例。softmax模型可以用来给不同的对象分配概率。即使在之后,我们训练更加精细的模型时,最后一步也需要用softmax来分配概率。
softmax回归(softmax regression)分两步:第一步
为了得到一张给定图片属于某个特定数字类的证据(evidence),我们对图片像素值进行加权求和。如果这个像素具有很强的证据说明这张图片不属于该类,那么相应的权值为负数,相反如果这个像素拥有有利的证据支持这张图片属于这个类,那么权值是正数。
下面的图片显示了一个模型学习到的图片上每个像素对于特定数字类的权值。红色代表负数权值,蓝色代表正数权值。
数字的特征
我们也需要加入一个额外的偏置量(bias),因为输入往往会带有一些无关的干扰量。因此对于给定的输入图片x它代表的是数字i的证据可以表示为
求和
其中Wi代表权重,bi 代表数字 i 类的偏置量,j 代表给定图片 x 的像素索引用于像素求和。然后用softmax函数可以把这些证据转换成概率 y:
激励函数
这的softmax可是看做是一个sigmoid形式的函数。把我们定义的线性函数的输出转换成我们想要的格式,也就是关于10个数字类的概率分布。因此,给定一张图片,它对于每一个数字的吻合度可以被softmax函数转换成为一个概率值。
归一化处理
展开等式右边的子式,可以得到:
softmax使用的公式
对于softmax回归模型可以用下面的图解释,对于输入的xs加权求和,再分别加上一个偏置量,最后再输入到softmax函数中:
softmax运行方式
如果把它写成一个等式,我们可以得到:
softmax数学表达式
我们也可以用向量表示这个计算过程:用矩阵乘法和向量相加。这有助于提高计算效率。(也是一种更有效的思考方式):
softmax矩阵表现形式
更进一步,可以写成更加紧凑的方式:
最终会使用的表达式
TensorFlow实现softmax
- # create a softmax regression
- import tensorflow as tf
- from tensorflow.examples.tutorials.mnist import input_data
- mnist = input_data.read_data_sets("/home/fly/TensorFlow/mnist", one_hot=True)
- x = tf.placeholder(tf.float32,[None, 784])
- W = tf.Variable(tf.zeros([784, 10]))
- b = tf.Variable(tf.zeros([10]))
- y = tf.nn.softmax(tf.matmul(x,W)+b)
- y_ = tf.placeholder(tf.float32,[None, 10])
- cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
- train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
- init = tf.initialize_all_variables()
- sess = tf.Session()
- sess.run(init)
- for i in range(1000):
- batch_xs, batch_ys = mnist.train.next_batch(100)
- sess.run(train_step, feed_dict = {x: batch_xs, y_: batch_ys})
- correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
- accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- print(sess.run(accuracy, feed_dict={x:mnist.test.images, y_: mnist.test.labels}))
Fly
2016.6