一、几个函数
- RandomShuffleQueue类
__init__(self, capacity, min_after_dequeue,dtypes, shapes=None,names=None, seed=None, shared_name=None, name="random_shuffle_queue")
queue = tf.RandomShuffleQueue(...):创建一个queue,按随机顺序进行dequeue
RandomShuffleQueue有一定的容量限制capacity,支持多个生产者和消费者
RandomShuffleQueue中的每个元素是固定长度的tensor 元组,数据类型由dtypes定义,形状为shapes。如果shapes没有定义,那么不同的queue元素可能有不同的形状,此时就不能使用dqueue_many。如果shapes定义了,则所有的元素必须有相同的形状
min_after_dequeue决定queue在dequeue以后要保持的元素个数,如果没有足够的元素,就会block住dequeue的相关操作,直到有足够元素进来。当queue关闭,则这个参数被忽略
- enqueue(self, vals, name=None)
enqueue_op = queue.enqueue(...) 创建enqueue元素到queue中的操作
如果操作执行时queue是满的,则会block住
vals是一个tensor或一个tensor的list/tuple,或者是一个字典,它相当于enqueue操作时的数据池
enqueue操作是要手动触发的,也就是不是说像一般的那种计算,会把enqueue作为依赖操作被执行
- queue.dequeue(self, name=None)
从queue中取出一个元素
- Coordinator类
__init__(self, clean_stop_exception_types=None)
coord = tf.Coordinator() 协调线程的执行
- QueueRunner类
__init__(self, queue=None, enqueue_ops=None, close_op=None, cancel_op=None, queue_closed_exception_types=None,queue_runner_def=None, import_scope=None)
说明
qr = tf.train.QueueRunner(...) 为一个queue保持一系列enqueue操作,每个操作以一个线程执行
queue: a Queue
enqueue_ops: 一个enqueue ops列表
close_op: 指定关闭queue的操作
cancel_op:指定关闭以及取消挂起的enqueue ops的操作
- qr.create_threads(self, sess, coord=None, daemon=False, start=False)
为给定的sess创建多个线程以执行enqueue ops
start:如果为False,则需要手动调用 start()来启动
- start_queue_runners
start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection=ops.GraphKeys.QUEUE_RUNNERS)
tf.train.start_queue_runners(...) 启动图中所有的queue runners,与add_queue_runner()配合使用
start: `False`只是创建线程,但是没有启动
二、实例
def example1(): """
最简单的例子,只使用enqueue和dequeue
:return:
"""
example = tf.constant(2, "float32", [2, 2])
# 创建一个queue
# tf.RandomShuffleQueue(capacity,: queue的容量
# min_after_dequeue, : 保证queue中最少的个数
# dtypes,
# shapes=None,...)
queue = tf.RandomShuffleQueue(10, 0, "float32", shapes=[2, 2])
# 为queue添加enqueue操作
enqueue_op = queue.enqueue(example)
# 为queue添加dequeue操作
inputs = queue.dequeue()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(enqueue_op)
print(sess.run(inputs))
def example2():
"""
使用queue runner来管理多个enqueue线程,用coord来关闭线程
:return:
"""
data = tf.constant(2, "float32", [2, 2])
example = [data, data, data, data, data, data, data, data]
queue = tf.RandomShuffleQueue(10, 0, "float32", shapes=[2, 2])
enqueue_op = queue.enqueue(example) qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
coord = tf.train.Coordinator() inputs = queue.dequeue()
with tf.Session() as sess:
threads = qr.create_threads(sess, coord, start=True)
sess.run(tf.global_variables_initializer())
print(sess.run(inputs))
# 用coord来停止所有的enqueu线程
coord.request_stop()
coord.join(threads)