项目已上传至 GitHub —— file-queue
生成样例数据
先生成 TFRecord 格式的样例数据,Example 的结构如下,表示第1个文件中的第1个数据
{
'i':0,
'j':0
}
生成数据的代码如下(以下代码都实现自《TensorFlow:实战Google深度学习框架》)
import tensorflow as tf
# 创建TFRecord文件的帮助函数
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 模拟海量数据情况下将数据写入不同的文件
num_shards = 2 # 总共写入多少个文件
instances_per_shard = 2 # 每个文件有多少数据
for i in range(num_shards):
# 按0000n-of-0000m的后缀区分文件。n代表当前文件编号,m代表文件总数
filename = ('data/data.tfrecords-%.5d-of-%.5d' % (i, num_shards))
writer = tf.python_io.TFRecordWriter(filename)
# 将数据封装成Example结构并写入TFRecord文件
for j in range(instances_per_shard):
example = tf.train.Example(
features=tf.train.Features(feature={
'i': _int64_feature(i),
'j': _int64_feature(j)
}))
writer.write(example.SerializeToString())
writer.close()
运行后会在 data 文件夹下生成两个文件,文件的命名后缀为 0000n-of-0000m,n代表当前文件编号,m代表文件总数
data/
data.tfrecords-00000-of-00002
data.tfrecords-00001-of-00002
读取文件数据
文件队列的生成主要使用两个函数
- tf.train.match_filenames_once():获取符合正则表达式的文件列表
- tf.train.string_input_producer():用文件列表创建一个输入队列
通过设置 shuffle 参数为 True,string_input_producer 会将文件的入队顺序打乱,所以出队顺序是随机的。随机打乱文件顺序和入队操作会跑在一个单独的线程上,不会影响出队的速度
当输入队列中的所有文件都处理完后,它会将文件列表中的文件重新加入队列。可以通过设置 num_epochs 参数来限制加载初始文件列表的最大轮数
读取文件队列数据的代码如下
import tensorflow as tf
# 获取文件列表
files = tf.train.match_filenames_once('data/data.tfrecords-*')
# 创建文件输入队列
filename_queue = tf.train.string_input_producer(files, shuffle=False)
# 读取并解析Example
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'i': tf.FixedLenFeature([], tf.int64),
'j': tf.FixedLenFeature([], tf.int64)
})
with tf.Session() as sess:
# 使用match_filenames_once需要用local_variables_initializer初始化一些变量
sess.run(
[tf.global_variables_initializer(),
tf.local_variables_initializer()])
# 打印文件名
print(sess.run(files))
# 用Coordinator协同线程,并启动线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# 获取数据
for i in range(6):
print(sess.run([features['i'], features['j']]))
coord.request_stop()
coord.join(threads)
这里需要使用 tf.local_variables_initializer() 初始化 tf.train.match_filenames_once() 中的变量,否则会报错
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value matching_filenames
运行结果如下
$ python read.py
[b'data/data.tfrecords-00000-of-00002'
b'data/data.tfrecords-00001-of-00002']
[0, 0]
[0, 1]
[1, 0]
[1, 1]
[0, 0]
[0, 1]
最后两个输出结果是第一个文件的第二遍输出,可知 string_input_producer 函数将初始文件列表重新加入了队列中
组合样例数据
可以使用两种函数组合样例数据,它们出队时得到的是一个 batch 的样例,它们的区别在于 shuffle_batch 函数会将数据顺序打乱
- tf.train.batch()
- tf.train.shuffle_batch()
使用 tf.train.batch() 的方法如下
import tensorflow as tf
# 获取文件列表
files = tf.train.match_filenames_once('data/data.tfrecords-*')
# 创建文件输入队列
filename_queue = tf.train.string_input_producer(files, shuffle=False)
# 读取并解析Example
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'i': tf.FixedLenFeature([], tf.int64),
'j': tf.FixedLenFeature([], tf.int64)
})
# i代表特征向量,j代表标签
example, label = features['i'], features['j']
# 一个batch中的样例数
batch_size = 3
# 文件队列中最多可以存储的样例个数
capacity = 1000 + 3 * batch_size
# 组合样例
example_batch, label_batch = tf.train.batch(
[example, label], batch_size=batch_size, capacity=capacity)
with tf.Session() as sess:
# 使用match_filenames_once需要用local_variables_initializer初始化一些变量
sess.run(
[tf.global_variables_initializer(),
tf.local_variables_initializer()])
# 用Coordinator协同线程,并启动线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# 获取并打印组合之后的样例。真实问题中一般作为神经网路的输入
for i in range(2):
cur_example_batch, cur_label_batch = sess.run(
[example_batch, label_batch])
print(cur_example_batch, cur_label_batch)
coord.request_stop()
coord.join(threads)
运行结果如下
$ python batching.py
[0 0 1] [0 1 0]
[1 0 0] [1 0 1]
可以看到单个的数据被组织成 3 个一组的 batch
以下是使用 tf.train.shuffle_batch() 的方法,min_after_dequeue 参数限制了出队时队列中元素的最少个数,当队列元素个数太少时,随机的意义就不大了
example_batch,label_batch = tf.train.shuffle_batch(
[example,label],batch_size=batch_size,
capacity=capacity,min_after_dequeue=30)