一、程序介绍
1、包导入
# Author : Hellcat
# Time : 17-12-29 import os
import numpy as np
np.set_printoptions(threshold=np.inf)
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
from tensorflow.examples.tutorials.mnist import input_data
2、TFRecord录入格式转换
TFRecord的录入格式是确定的,整数或者二进制,在train函数中能查看所有可以接受类型
def _int64_feature(value):
"""生成整数数据属性"""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value):
"""生成字符型数据属性"""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
3、TFRecord文件写入测试
将mnist数据以每张图片为单位写入同一个TFR文件,
实际上就是每次把一个图片相关信息都写入,注意文件类型,二级制数据需要以string的格式保存
def TFRecord_write():
"""将mnist数据集写入TFR文件"""
mnist = input_data.read_data_sets('./Data_Set/Mnist_data',
dtype=tf.uint8,one_hot=True) images = mnist.train.images
labels = mnist.train.labels
pixels = images.shape[1] # 784
num_examples = mnist.train.num_examples # 55000 # TFRecords文件地址
filename = './TFRecord_Output/mnist_train.tfrecords' if not os.path.exists('./TFRecord_Output/'):
os.makedirs('./TFRecord_Output/') # 创建一个writer书写文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
# 提取单张图像矩阵并转换为字符串
image_raw = images[index].tostring()
# 将单张图片相关数据写入TFR文件
example = tf.train.Example(features=tf.train.Features(feature={
'pixels': _int64_feature(pixels),
'label': _int64_feature(np.argmax(labels[index])),
'img_raw': _bytes_feature(image_raw)
}))
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()
调用,
if __name__=='__main__':
TFRecord_write()
输出如下,
4、TFRecord文件读取测试
实际的读取基本单位和存入的基本单位是一一对应的,当然也可以复数读取,但是由于tf后续有batch拼接的函数,所以意义不大
def TFRecord_read():
"""从TFR文件读取mnist数据集合"""
# 创建一个reader读取文件
reader = tf.TFRecordReader()
# 创建读取文件队列维护文件列表
filename_queue = tf.train.string_input_producer(['./TFRecord_Output/mnist_train.tfrecords']) # 读取数据
# 每次读取一个
# _, serialized_example = reader.read(filename_queue)
# 每次读取多个
_, serialized_example = reader.read_up_to(filename_queue,10) # 解析样例
# 解析函数选择必须和上面读取函数选择相一致
# 解析单个样例
# features = tf.parse_single_example(
# 同时解析所有样例
features = tf.parse_example(
serialized_example,
features={
'img_raw': tf.FixedLenFeature([],tf.string),
'pixels': tf.FixedLenFeature([],tf.int64),
'label': tf.FixedLenFeature([],tf.int64),
})
# 解析二进制数据格式,将之按照uint8格式解析
images = tf.decode_raw(features['img_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32)
pixels = tf.cast(features['pixels'],tf.int32) batch_size = 2
capacity = 1000 + 3 * batch_size images.set_shape([10,784])
labels.set_shape(10)
pixels.set_shape(10)
image_batch, label_batch, pixel_batch = tf.train.batch(
[images, labels, pixels], batch_size=batch_size, capacity=capacity)
# 线程控制器
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord) # 这里指代的是读取数据的线程,如果不加的话队列一直挂起
for i in range(10):
# print(images, labels, pixels)
# print(sess.run(images))
image, label, pixel = sess.run([image_batch,label_batch,pixel_batch])
# image, label, pixel = sess.run([images,labels,pixels])
print(image.shape,label,pixel)
输出,
拼接batch尺寸为2,每次读取10个数据
可以看到,这里batch尺寸指定的实际上是读取次数
(2, 10, 784)
[[7 3 4 6 1 8 1 0 9 8]
[0 3 1 2 7 0 2 9 6 0]][[784 784 784 784 784 784 784 784 784 784]
[784 784 784 784 784 784 784 784 784 784]]
……
注意读取数目和解析数目选择的函数是要对应的,
# 读取数据
# 每次读取一个
# _, serialized_example = reader.read(filename_queue)
# 每次读取多个,这里指定10个
_, serialized_example = reader.read_up_to(filename_queue,10) # 解析样例
# 解析函数选择必须和上面读取函数选择相一致
# 解析单个样例
# features = tf.parse_single_example()
# 同时解析所有样例
features = tf.parse_example()
值得注意的是这句,
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
虽然后续未必会调用(coord实际上还是会调用用于协调停止),但实际上控制着队列的数据读取部分的启动,注释掉后会导致队列有出无进进而挂起。
5、TFRecord文件批量生成
def TFR_gen():
"""TFR样例数据生成"""
# 定义写多少个文件(数据量大时可以写入多个文件加速)
num_shards = 2
# 定义每个文件中放入多少数据
instances_per_shard = 2
for i in range(num_shards):
file_name = './TFRecord_Output/data.tfrecords-{}-of-{}'.format(i,num_shards)
writer = tf.python_io.TFRecordWriter(file_name)
for j in range(instances_per_shard):
example = tf.train.Example(features=tf.train.Features(feature={
'i':_int64_feature(i),
'j':_int64_feature(j),
'list':_bytes_feature(bytes([1,2,3]))
}))
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()
输出如下,
6、TFRecord文件读取测试
def TFR_load():
"""批量载入TFR数据"""
# 匹配文件名
files = tf.train.match_filenames_once('./TFRecord_Output/data.tfrecords-*')
import glob
# files = glob.glob('./TFRecord_Output/data.tfrecords-*')
# 载入文件名
filename_queue = tf.train.string_input_producer(files,shuffle=True) reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'i':tf.FixedLenFeature([],tf.int64),
'j':tf.FixedLenFeature([],tf.int64),
'list':tf.FixedLenFeature([],tf.string)
})
'''
# tf.train.match_filenames_once操作中产生了变量
# 值得注意的是局部变量,需要用下面的初始化函数初始化
sess.run(tf.local_variables_initializer())
print(sess.run(files))
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
for i in range(6):
print(sess.run([features['i'],features['j']]))
coord.request_stop()
coord.join(threads)
''' example, label, array = features['i'], features['j'], features['list']
# 每个batch的中样例的个数
batch_size = 3
# 队列中样例的个数
capacity = 1000 + 3 * batch_size suffer = False
# batch操作实际代指的就是数据读取和预处理操作
if suffer is not True:
example_batch, label_batch, array_batch = tf.train.batch(
[example, label, array], batch_size=batch_size, capacity=capacity)
else:
# 不同线程处理各自的文件
# 随机包含各个线程选择文件名的随机和文件内部数据读取的随机
example_batch, label_batch, array_batch = tf.train.shuffle_batch(
[example, label, array], batch_size=batch_size, capacity=capacity,
min_after_dequeue=30) sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 这里指代的是读取数据的线程,如果不加的话队列一直挂起
for i in range(2):
cur_example_batch, cur_label_batch, cur_array_batch = sess.run([example_batch, label_batch, array_batch])
print(cur_example_batch, cur_label_batch, cur_array_batch) coord.request_stop()
coord.join(threads)
注意下面介绍,
# tf.train.match_filenames_once操作中产生了变量
# 值得注意的是局部变量,需要用下面的初始化函数初始化
sess.run(tf.local_variables_initializer())
batch生成的两个函数如下,
suffer = False
# batch操作实际代指的就是数据读取和预处理操作
if suffer is not True:
example_batch, label_batch, array_batch = tf.train.batch(
[example, label, array], batch_size=batch_size, capacity=capacity)
else:
# 不同线程处理各自的文件
# 随机包含各个线程选择文件名的随机和文件内部数据读取的随机
example_batch, label_batch, array_batch = tf.train.shuffle_batch(
[example, label, array], batch_size=batch_size, capacity=capacity,
min_after_dequeue=30)
- 单一文件多线程,那么选用tf.train.batch(需要打乱样本,有对应的tf.train.shuffle_batch)
- 多线程多文件的情况,一般选用tf.train.batch_join来获取样本(打乱样本同样也有对应的tf.train.shuffle_batch_join使用)
二、batch和batch_join的说明
1、文件准备
$ echo -e "Alpha1,A1\nAlpha2,A2\nAlpha3,A3" > A.csv
$ echo -e "Bee1,B1\nBee2,B2\nBee3,B3" > B.csv
$ echo -e "Sea1,C1\nSea2,C2\nSea3,C3" > C.csv
$ cat A.csv
Alpha1,A1
Alpha2,A2
Alpha3,A3
2、单个Reader,单个样本
import tensorflow as tf
# 生成一个先入先出队列和一个QueueRunner
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
# 定义Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定义Decoder
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
# 运行Graph
with tf.Session() as sess:
coord = tf.train.Coordinator() #创建一个协调器,管理线程
threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。
for i in range(10):
print example.eval() #取样本的时候,一个Reader先从文件名队列中取出文件名,读出数据,Decoder解析后进入样本队列。
coord.request_stop()
coord.join(threads)
# outpt
# Alpha1
# Alpha2
# Alpha3
# Bee1
# Bee2
# Bee3
# Sea1
# Sea2
# Sea3
# Alpha1
3、单个Reader,多个样本
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
## filenames = tf.train.match_filenames_once('.\data\*.csv')
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
# 使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。Decoder解后数据会进入这个队列,再批量出队。
# 虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。
example_batch, label_batch = tf.train.batch(
[example, label], batch_size=5)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
print example_batch.eval()
coord.request_stop()
coord.join(threads)
# output
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']
# ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3']
# ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2']
# ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1']
# ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3']
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
4、多Reader,多个样本
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [['null'], ['null']]
example_list = [tf.decode_csv(value, record_defaults=record_defaults)
for _ in range(2)] # Reader设置为2
# 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。
example_batch, label_batch = tf.train.batch_join(
example_list, batch_size=5)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
print example_batch.eval()
coord.request_stop()
coord.join(threads) # output
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']
# ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3']
# ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2']
# ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1']
# ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3']
# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
tf.train.batch
与tf.train.shuffle_batch'
数是单个Reader读取,但是可以多线程。tf.train.batch_join'
和tf.train.shuffle_batch_join
可设置多Reader读取,每个Reader使用一个线程。至于两种方法的效率,单Reader时,2个线程就达到了速度的极限。多Reader时,2个Reader就达到了极限。所以并不是线程越多越快,甚至更多的线程反而会使效率下降。
在这个例子中, 虽然只使用了一个文件名队列, 但是TensorFlow依然能保证多个文件阅读器从同一次迭代(epoch)的不同文件中读取数据,知道这次迭代的所有文件都被开始读取为止。(通常来说一个线程来对文件名队列进行填充的效率是足够的)
另一种替代方案是: 使用tf.train.shuffle_batch
函数,设置num_threads
的值大于1。 这种方案可以保证同一时刻只在一个文件中进行读取操作(但是读取速度依然优于单线程),而不是之前的同时读取多个文件。这种方案的优点是:
- 避免了两个不同的线程从同一个文件中读取同一个样本。
- 避免了过多的磁盘搜索操作。
简单来说,
单一文件多线程,那么选用tf.train.batch(需要打乱样本,有对应的tf.train.shuffle_batch)
多线程多文件的情况,一般选用tf.train.batch_join来获取样本(打乱样本同样也有tf.train.shuffle_batch_join)
『TensorFlow』读书笔记_TFRecord学习的更多相关文章
-
『TensorFlow』读书笔记_降噪自编码器
『TensorFlow』降噪自编码器设计 之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...
-
『TensorFlow』读书笔记_VGGNet
VGGNet网络介绍 VGG系列结构图, 『cs231n』卷积神经网络工程实践技巧_下 1,全部使用3*3的卷积核和2*2的池化核,通过不断加深网络结构来提升性能. 所有卷积层都是同样大小的filte ...
-
『TensorFlow』读书笔记_ResNet_V2
『PyTorch × TensorFlow』第十七弹_ResNet快速实现 要点 神经网络逐层加深有Degradiation问题,准确率先上升到饱和,再加深会下降,这不是过拟合,是测试集和训练集同时下 ...
-
『TensorFlow』读书笔记_SoftMax分类器
开坑之前 今年3.4月份的时候就买了这本书,同时还买了另外一本更为浅显的书,当时读不懂这本,所以一度以为这本书很一般,前些日子看见知乎有人推荐它,也就拿出来翻翻看,发现写的的确蛮好,只是稍微深一点,当 ...
-
『TensorFlow』读书笔记_多层感知机
多层感知机 输入->线性变换->Relu激活->线性变换->Softmax分类 多层感知机将mnist的结果提升到了98%左右的水平 知识点 过拟合:采用dropout解决,本 ...
-
『TensorFlow』读书笔记_简单卷积神经网络
如果你可视化CNN的各层级结构,你会发现里面的每一层神经元的激活态都对应了一种特定的信息,越是底层的,就越接近画面的纹理信息,如同物品的材质. 越是上层的,就越接近实际内容(能说出来是个什么东西的那些 ...
-
『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_上
完整项目见:Github 完整项目中最终使用了ResNet进行分类,而卷积版本较本篇中结构为了提升训练效果也略有改动 本节主要介绍进阶的卷积神经网络设计相关,数据读入以及增强在下一节再与介绍 网络相关 ...
-
『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_下
数据读取部分实现 文中采用了tensorflow的从文件直接读取数据的方式,逻辑流程如下, 实现如下, # Author : Hellcat # Time : 2017/12/9 import os ...
-
『TensorFlow』读书笔记_AlexNet
网络结构 创新点 Relu激活函数:效果好于sigmoid,且解决了梯度弥散问题 Dropout层:Alexnet验证了dropout层的效果 重叠的最大池化:此前以平均池化为主,最大池化避免了平均池 ...
随机推荐
-
ajax提交Form
Jquery的$.ajax方法可以实现ajax调用,要设置url,post,参数等. 如果要提交现有Form需要写很多代码,何不直接将Form的提交直接转移到ajax中呢. 以前的处理方法 如Form ...
-
5 HandlerIterator处理程序迭代器类——Live555源码阅读(一)基本组件类
这是Live555源码阅读的第一部分,包括了时间类,延时队列类,处理程序描述类,哈希表类这四个大类. 本文由乌合之众 lym瞎编,欢迎转载 my.oschina.net/oloroso Handler ...
-
This application failed to start because it could not find or load the Qt platform plugin “windows”错误解决方法
这是一个困扰我很久的问题,关于Qt下生成的exe文件在没有安装Qt的机器上无法运行的问题.Qt是编写C++图形界面的一个很好工具,比MFC来的直观.可是,Qt的安装却是一个让人头疼的事情.早在上个学期 ...
-
Shell curl 和 wget 使用代理IP
Linux Shell 提供两个非常实用的命令来爬取网页,它们分别是 curl 和 wget curl 和 wget 使用代理 curl 支持 http.https.socks4.socks5 wge ...
-
jemter分布式部署及linux下分布式脚本执行
jmeter进行接口性能测试,占用内存较大,在模拟千万计并发用户时,使用分布式部署进行分压测试. 操作步骤:选择一台机器作为调度机,其他机器作为执行机 一.jmeter分布式部署 前提条件:A.执行机 ...
-
第十四届智能车培训 PLL锁相环
什么是锁相环? PLL(Phase Locked Loop): 为锁相回路或锁相环,用来统一整合时脉讯号,使高频器件正常工作,如内存的存取资料等.PLL用于振荡器中的反馈技术. 许多电子设备要正常工作 ...
-
maven整合ssh框架笔记
具体工程会上传文件sshpro <?xml version="1.0" encoding="UTF-8"?> <web-app xmlns:x ...
-
-bash: ls: No such file or directory 产生的原因及修改方法
ubuntu出现如下错误: { Welcome to Ubuntu 16.04.5 LTS (GNU/Linux 4.15.0-42-generic x86_64) * Documentation: ...
-
如何使用jqueryUi的datepicker日历控件?
参考: http://www.jb51.net/article/85007.htm 这里的日历控件是, 基于jquery的jqureyui中的一个 widget. 需要js 文件: 外部的js文件, ...
-
Django Admin实现三级联动(省市区)
通过自定义Admin的模板文件实现省市区的三级联动.要求创建记录时,根据省>市>区的顺序选择依次显示对应数据. 修改记录时默认显示已存在的数据. Model class Member(mo ...