tensorflow中使用mnist数据集训练全连接神经网络
——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师
前期准备:mnist数据集下载,并存入data目录:
文件列表:四个文件,分别为训练和测试集数据
Four files are available on 官网 http://yann.lecun.com/exdb/mnist/ :
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz:
training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz:
test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz:
test set labels (4542 bytes)
一、主要思路:
1、训练集输入数据X为28×28图像,和 Y_ onehot label
2、构建一个三层NN,input layer,one hidden layer,outputlayer
3、使用指数衰减学习率,交叉熵loss,移动平均loss构建NN
二、主要代码:
forward构建:
//mnist_forward.py
import tensorflow as tf
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500
def get_weight(shape, regularizer):
w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
if regularizer != None:
tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
return w
def get_bias(shape):
b = tf.Variable(tf.zeros(shape))
return b
def forward(x,regularizer):
w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)
b1 = get_bias([LAYER1_NODE])
y1 = tf.nn.relu(tf.matmul(x, w1) + b1)
w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)
b2 = get_bias([OUTPUT_NODE])
y = tf.matmul(y1, w2) + b2
return y
backward构建:
//mnist_backward.py
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os
BATCH_SIZE = 200
LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99
REGULARIZER = 0.0001
STEPS = 50000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"
def backward(mnist):
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
y = mnist_forward.forward(x, REGULARIZER)
global_step = tf.Variable(0, trainable=False)
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_,1))
cem = tf.reduce_mean(ce)
loss = cem + tf.add_n(tf.get_collection('losses'))
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples / BATCH_SIZE,
LEARNING_RATE_DECAY,
staircase = True
)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step = global_step)
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step, ema_op]):
train_op = tf.no_op(name='train')
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
for i in range(STEPS):
xs, ys = mnist.train.next_batch(BATCH_SIZE)
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 1000 ==0:
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step = global_step)
def main():
mnist = input_data.read_data_sets("./data/", one_hot = True)
backward(mnist)
if __name__ == '__main__':
main()
tensorflow中使用mnist数据集训练全连接神经网络-学习笔记的更多相关文章
-
【TensorFlow/简单网络】MNIST数据集-softmax、全连接神经网络,卷积神经网络模型
初学tensorflow,参考了以下几篇博客: soft模型 tensorflow构建全连接神经网络 tensorflow构建卷积神经网络 tensorflow构建卷积神经网络 tensorflow构 ...
-
深度学习tensorflow实战笔记(1)全连接神经网络(FCN)训练自己的数据(从txt文件中读取)
1.准备数据 把数据放进txt文件中(数据量大的话,就写一段程序自己把数据自动的写入txt文件中,任何语言都能实现),数据之间用逗号隔开,最后一列标注数据的标签(用于分类),比如0,1.每一行表示一个 ...
-
TensorFlow之DNN(二):全连接神经网络的加速技巧(Xavier初始化、Adam、Batch Norm、学习率衰减与梯度截断)
在上一篇博客<TensorFlow之DNN(一):构建“裸机版”全连接神经网络>中,我整理了一个用TensorFlow实现的简单全连接神经网络模型,没有运用加速技巧(小批量梯度下降不算哦) ...
-
TensorFlow之DNN(一):构建“裸机版”全连接神经网络
博客断更了一周,干啥去了?想做个聊天机器人出来,去看教程了,然后大受打击,哭着回来补TensorFlow和自然语言处理的基础了.本来如意算盘打得挺响,作为一个初学者,直接看项目(不是指MINIST手写 ...
-
Tensorflow 多层全连接神经网络
本节涉及: 身份证问题 单层网络的模型 多层全连接神经网络 激活函数 tanh 身份证问题新模型的代码实现 模型的优化 一.身份证问题 身份证号码是18位的数字[此处暂不考虑字母的情况],身份证倒数第 ...
-
MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...
-
Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)
基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html 摘要 在前面的博文中,我详细介绍了Caf ...
-
tensorflow读取本地MNIST数据集
tensorflow读取本地MNIST数据集 数据放入文件夹(不要解压gz): >>> import tensorflow as tf >>> from tenso ...
-
Keras入门——(1)全连接神经网络FCN
Anaconda安装Keras: conda install keras 安装完成: 在Jupyter Notebook中新建并执行代码: import keras from keras.datase ...
随机推荐
-
python动态获取对象的属性和方法 (转载)
首先通过一个例子来看一下本文中可能用到的对象和相关概念. #coding:utf-8 import sys def foo():pass class Cat(object): def __init__ ...
-
项目管理、测试管理、代码bug 管理
1.友盟统计 阿里旗下的产品 http://www.umeng.com/ 2.bugly 腾讯旗下的产品 http://bugly.qq.com/ 3.禅道 项目管理工具 需要部署到 ...
-
[HDOJ5451]Best Solver(乱搞)
题目:http://acm.hdu.edu.cn/showproblem.php?pid=5451 分析:A=5+2根号6 B=6-2根号6 n=1+2^x 那么A^n+B^n是整数 注意到0< ...
-
【转载】MyBatis之传入参数
原文地址:http://blog.csdn.net/liaoxiaohua1981/article/details/6862764 在MyBatis的select.insert.update.dele ...
-
Qt历史版本下载
今天找到一个Qt官方下载任意版本的链接,在这里分享给大家~ http://download.qt.io/archive/qt/ 里面有Qt的历史版本 原文链接http://www.donnyblog. ...
-
Spreadsheets
很水的一道题,提醒自己要认真,做的头都快晕了.考虑26的特殊情况. D - Spreadsheets Time Limit:10000MS Memory Limit:65536KB 6 ...
-
比列的数目更多,以便找到第一k小值
问题叙述性说明:现有n作为一个有序序列(2,3,9),(3,5,11,23),(1,4,7,9,15,17,20),(8,15,35,9),(20,30,40),第k小值. 问题分析:可用多路归并排序 ...
-
linux操作系统基础篇(五)
Linux网络以及rpm安装yum源的配置 1.Linux网络 1. 使用ifconfig命令来维护网络1) fconfig命令的功能:显示所有正在启动的网卡的详细信息或设定系统中网卡的IP地址.2) ...
-
Python3 socket网络编程(一)
Socket的定义 套接字是为特定网络协议(例如TCP/IP,ICMP/IP,UDP/IP等)套件对上的网络应用程序提供者提供当前可移植标准的对象.它们允许程序接受并进行连接,如发送和接受数据.为了建 ...
-
Java学习笔记44(多线程一:Thread类)
多线程的概念:略 多线程的目的:提高效率 主线程: package demo; //主线程 public class Demo { public static void main(String[] a ...