TensorFlow输入文件队列

时间:2021-10-16 10:00:26

        TensorFlow提供了tf.train.match_filenames_once函数来获取符合一个正则表达式的所有文件,得到的文件列表可以通过tf.train.string_input_producer函数来进行有效管理。

下面给出一个简单的程序来生成样例数据:

import tensorflow as tf
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#模拟海量数据情况下将数据写入不同文件,num_shards定义了总共写入多少个文件
num_shards = 2
#instances_per_shard定义了每个文件中有多少个数据
instances_per_shard = 2
for i in range(num_shards):
    #将数据分为多个文件时,可以将不同文件以类似000n-of-0000m的后缀区分。其中m代表数据总共被存在了多少个文件中,n代表当前编号。
    #式样的方式即方便通过了正则表达式获取文件列表,,又在文件名中加入了更多信息
    filename = ('Records/data.tfrecords-%.5d-of-%.5d' % (i, num_shards)) 
    # 将Example结构写入TFRecord文件。
    writer = tf.python_io.TFRecordWriter(filename)
    #将数据封装成Example结构并写入TFRecord文件
    for j in range(instances_per_shard):
    # Example结构仅包含当前样例属于第几个文件以及是当前文件的第几个样本。
        example = tf.train.Example(features=tf.train.Features(feature={
            'i': _int64_feature(i),
            'j': _int64_feature(j)}))
        writer.write(example.SerializeToString())
    writer.close()  

 程序运行之后,会在指定目录下生成俩个文件,data.tfrecords-00000-of-00002与data.tfrecords-00001-of-00002。每个文        件储存了俩个样例。之后。以下代码展示了tf.train.match_filenames_once函数与tf.train.string_input_producer函数使用方法:

#使用tf.train.match_filenames_once函数获取文件列表
files = tf.train.match_filenames_once("Records/data.tfrecords-*")
#通过tf.train.string_input_producer函数创建输入队列,输入队列中的文件列表为 tf.train.match_filenames_once函数获取的文件列表。
#这里将shuffle设为False来避免随机打乱读文件的顺序。但解决真实问题时一般会设为True。
filename_queue = tf.train.string_input_producer(files, shuffle=False) 
#读取并解析一个样本
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
      serialized_example,
      features={
          'i': tf.FixedLenFeature([], tf.int64),
          'j': tf.FixedLenFeature([], tf.int64),
      })
with tf.Session() as sess:
    #虽然本段程序没有声明一些变量,但使用tf.train.match_filenames_once函数时需要初始化一些变量
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    #打印文件列表
    print(sess.run(files))
    #声明tf.train.Coordinator类来协同不同线程,并启动线程
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    #多次执行获取数据
    for i in range(6):
        print(sess.run([features['i'], features['j']]))
    coord.request_stop()
    coord.join(threads)
    注意这里需要tf.local_variables_initializer()来初始化tf.train.match_filenames_once中的变量。并且tf.initialize_all_variables().run()函数于2017-03-02删除了。得到如下
[b'Records\\data.tfrecords-00000-of-00002'
 b'Records\\data.tfrecords-00001-of-00002']
[0, 0]
[0, 1]
[1, 0]
[1, 1]
[0, 0]
[0, 1]

最后俩个结果是第一个文件的第二遍输出,故而可知tf.train.string_input_producer函数会将初始文件列表重新加入队列。
        tf.train.string_input_producer函数会使用初始化时提供的文件列表创建一个输入队列,创建好的输入队列可以作为文件读取函数的参数。每次调用文件读取函数,该函数会先判断当前是否有已打开的文件可读,如果没有或者打开的文件已经读完,这个函数会从输入队列出队一个文件并且读取他。通过shuffle参数是否来打乱文件列表中出队顺序,为True时会被打乱。随机打乱文件顺序及加入输入队列的过程会跑在一个单独的线程上,这样不会影响读取文件的速度,且其生成的输入队列可以同时被多个文件读取线程操作,其输入队列还会将队列中的文件均匀的分给不同线程,不会有有些文件被处理多次二有些文件还没有被处理的情况。tf.train.string_input_producer可以通过设置num_epochs参数来限制加载初始文件列表的最大轮数。当所有文件被使用设定轮数,如果继续尝试读取新文件,输入队列会报OutOfRange的错误。一般参数设为1.