基于MNIST数据集使用TensorFlow训练一个包含一个隐含层的全连接神经网络

时间:2022-08-31 10:16:44

包含一个隐含层的全连接神经网络结构如下:

基于MNIST数据集使用TensorFlow训练一个包含一个隐含层的全连接神经网络

包含一个隐含层的神经网络结构图

以MNIST数据集为例,以上结构的神经网络训练如下:

#coding=utf-8
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf # 加载数据
mnist = input_data.read_data_sets('/home/workspace/python/tf/data/mnist', one_hot=True)
"""
# 创建模型
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
"""
x = tf.placeholder(tf.float32, [None, 784])
W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
b1 = tf.Variable(tf.zeros([500]))
W2 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
b2 = tf.Variable(tf.zeros([10]))
layer1 = tf.nn.relu(tf.matmul(x, W1) + b1)
y = tf.matmul(layer1, W2) + b2 # 正确的样本标签
y_ = tf.placeholder(tf.float32, [None, 10]) # 损失函数选择softmax后的交叉熵,结果作为y的输出
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.InteractiveSession()
tf.global_variables_initializer().run() # 训练过程
for _ in range(5000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
if _%1000 == 0:
# 使用测试集评估准确率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print (sess.run(accuracy, feed_dict = {x: mnist.test.images,
y_: mnist.test.labels}))

注意:权重向量初始化时使用tf.truncated_normal,而不要使用tf.zeros

以上代码大概能得到97.98%的准确率。

软件版本


TensorFlow 1.0.1  +  Python 2.7.12

基于MNIST数据集使用TensorFlow训练一个包含一个隐含层的全连接神经网络的更多相关文章

  1. 基于MNIST数据集使用TensorFlow训练一个没有隐含层的浅层神经网络

    基础 在参考①中我们详细介绍了没有隐含层的神经网络结构,该神经网络只有输入层和输出层,并且输入层和输出层是通过全连接方式进行连接的.具体结构如下: 我们用此网络结构基于MNIST数据集(参考②)进行训 ...

  2. TensorFlow之DNN(二):全连接神经网络的加速技巧(Xavier初始化、Adam、Batch Norm、学习率衰减与梯度截断)

    在上一篇博客<TensorFlow之DNN(一):构建“裸机版”全连接神经网络>中,我整理了一个用TensorFlow实现的简单全连接神经网络模型,没有运用加速技巧(小批量梯度下降不算哦) ...

  3. tensorflow中使用mnist数据集训练全连接神经网络-学习笔记

    tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...

  4. 【TensorFlow&sol;简单网络】MNIST数据集-softmax、全连接神经网络,卷积神经网络模型

    初学tensorflow,参考了以下几篇博客: soft模型 tensorflow构建全连接神经网络 tensorflow构建卷积神经网络 tensorflow构建卷积神经网络 tensorflow构 ...

  5. 深度学习tensorflow实战笔记(1)全连接神经网络(FCN)训练自己的数据(从txt文件中读取)

    1.准备数据 把数据放进txt文件中(数据量大的话,就写一段程序自己把数据自动的写入txt文件中,任何语言都能实现),数据之间用逗号隔开,最后一列标注数据的标签(用于分类),比如0,1.每一行表示一个 ...

  6. TensorFlow之DNN(一):构建&OpenCurlyDoubleQuote;裸机版”全连接神经网络

    博客断更了一周,干啥去了?想做个聊天机器人出来,去看教程了,然后大受打击,哭着回来补TensorFlow和自然语言处理的基础了.本来如意算盘打得挺响,作为一个初学者,直接看项目(不是指MINIST手写 ...

  7. Tensorflow 多层全连接神经网络

    本节涉及: 身份证问题 单层网络的模型 多层全连接神经网络 激活函数 tanh 身份证问题新模型的代码实现 模型的优化 一.身份证问题 身份证号码是18位的数字[此处暂不考虑字母的情况],身份证倒数第 ...

  8. caffe中全卷积层和全连接层训练参数如何确定

    今天来仔细讲一下卷基层和全连接层训练参数个数如何确定的问题.我们以Mnist为例,首先贴出网络配置文件: name: "LeNet" layer { name: "mni ...

  9. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

随机推荐

  1. nodejs express下使用redis管理session

    Session实现原理 实现请求身份验证的方式很多,其中一种广泛接受的方式是使用服务器端产生的Session ID结合浏览器的Cookie实现对Session的管理,一般来说包括以下4个步骤: 服务器 ...

  2. MySql&period;Data&period;dll 不支持输出参数

    insert INTO stu(name) VALUES('maimai'); set @ReturnValue=@@IDENTITY; string sql="insert INTO st ...

  3. php调用empty出现错误Can&&num;39&semi;t use function return value in write context

    php调用empty出现错误Can't use function return value in write context 2012-10-28 09:33:22 | 11391次阅读 | 评论:0 ...

  4. https&colon;&sol;&sol;docs&period;mongodb&period;org&sol;manual&sol;reference&sol;operator&sol;aggregation&sol;unwind&sol;&num;examples

    https://docs.mongodb.org/manual/reference/operator/aggregation/unwind/#examples http://www.clusterdb ...

  5. 半同步半异步模式的实现 - MSMQ实现

    半同步半异步模式的实现 - MSMQ实现 所谓半同步半异步是指,在某个方法调用中,有些代码行是同步执行方式,有些代码行是异步执行方式,下面我们来举个例子,还是以经典的PlaceOrder来说,哈哈. ...

  6. &lbrack;MySQL&rsqb; mysql 的读写锁与并发控制

    1.无论何时只要有多个查询在同一时刻修改数据,都会产生并发控制的问题 2.讨论mysql在两个层面,服务器层和存储引擎层,如何并发控制读写 3.举了个mbox邮箱文件的例子,说如果有多个进程同时对mb ...

  7. Asp&period;Net Core MongoDB

    废话不说直接上代码: using MongoDB.Bson.Serialization.Attributes; namespace XL.Core.MongoDB { public interface ...

  8. mysql 5&period;6 解压缩版安装教程

    MySQL 5.6 for Windows 解压缩版配置安装 听语音 | 浏览:68011 | 更新:2014-03-05 12:28 | 标签:mysql 1 2 3 4 5 6 7 分步阅读 My ...

  9. 在node中使用jwt签发与验证token

    1.什么是token token的意思是“令牌”,是服务端生成的一串字符串,作为客户端进行请求的一个标识. token是在服务端产生的.如果前端使用用户名和密码向服务端发送请求认证,服务端认证成功,那 ...

  10. Ext自定义控件 - 自学ExtJS

    本文所有思想表达均为个人意见,如有疑义请多批评,如有雷同不甚荣幸. 转载请注明出处:Nutk'Z http://www.cnblogs.com/nutkz/p/3448801.html 在用到ExtJ ...