Tensorflow 数据预读取--Queue

时间:2022-03-21 17:36:52

Google开源的深度学习框架Tensorflow在数据预取上做了一些特殊的特征来提高模型训练或者推理的效率,避免在IO上耗费过多的时间。本文通过几个简单例子介绍Tensorflow构建queue常用函数的使用方法。

    深度学习训练模型通常是建立在大数据基础上,一般情况下可以把数据都加载到内存避免训练时数据读取IO。但是,当数据占用空间较大,如图片集或者视频集,无法全部载入内存;另一种方式是在训练时再读取需要的数据,但是增加的IO耗时会让模型训练过程很漫长很漫长。

    Tensorflow提供了Queue这个工具来更好的解决这类问题。Queue构建了一个大小为capacity的缓存区,多线程执行数据的enqueue,神经网络模型从缓存区dequeue数据。如果capacity足够大,数据的加载和读取可以同时执行,没有阻塞,从而IO的时间几乎可以忽略不计。

slice_input_producer

过程描述:图片数据保存在本地,内存中保存所有图片的系统路径,现在构建Queue,从磁盘上读取并缓存数据。整个过程类似于:

Tensorflow 数据预读取--Queue

def slice_input_producer_demo(image_pair_path, summary_path):
# 重置graph
tf.reset_default_graph()
# 获取<图片一系统路径,图片二系统路径,标签信息>三个list(load_data函数见supplementary)
image_one_path_list, image_two_path_list, label_list = load_data()
## 构造数据queue
train_input_queue = tf.train.slice_input_producer([image_one_path_list, image_two_path_list, label_list], capacity=10 * batch_size)

## queue输出数据
img_one_queue = get_image(train_input_queue[0])
img_two_queue = get_image(train_input_queue[1])
label_queue = train_input_queue[2]

## shuffle_batch批量从queu批量读取数据
batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch([img_one_queue, img_two_queue, label_queue],batch_size=batch_size,capacity = 10 + 10* batch_size,min_after_dequeue = 10,num_threads=16,shapes=[(image_width, image_height, image_channel),(image_width, image_height, image_channel),()])

sess = tf.Session()
sess.run(tf.initialize_all_variables())

summary_writer = tf.train.SummaryWriter(summary_path, graph_def=sess.graph)

## 启动queue线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)

for i in range(10):
batch_img_one_val, batch_img_two_val, label = sess.run([batch_img_one, batch_img_two,batch_label])
for k in range(batch_size):
fig = plt.figure()
fig.add_subplot(1,2,1)
plt.imshow(batch_img_one_val[k])
fig.add_subplot(1,2,2)
plt.imshow(batch_img_two_val[k])
plt.show()


coord.request_stop()
coord.join(threads)
sess.close()
summary_writer.close()

整个过程很清晰,主要由以下几步组成:
1、图片的路径和标记信息载入内存:image_one_path_list, image_two_path_list, label_list = load_data()
2、构造第一个queue:train_input_queue = tf.train.slice_input_producer( [image_one_path_list, image_two_path_list, label_list], capacity=10 * batch_size)
3、从queue取出图片路径数据加载图片:img_one_queue = get_image(train_input_queue[0])
4、构造第二个queue:shuffle_queue,把图片数据enqueue到缓存区,批量dequeue输出结果。batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch([img_one_queue, img_two_queue, label_queue]...)

string_input_producer

string_input_producer从一个pipeline把字符串输出到一个queue。

def string_input_producer_demo(image_pair_path, summary_path):
tf.reset_default_graph()

image_one_path_list, image_two_path_list, label_list = load_data()
## 构造数据queue
train_input_queue = tf.train.string_input_producer(image_one_path_list, capacity=10 * batch_size)

## queue输出数据
img_one_queue = get_image(train_input_queue.dequeue())

sess = tf.Session()
sess.run(tf.initialize_all_variables())
summary_writer = tf.train.SummaryWriter(summary_path, graph_def=sess.graph)

## queue线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)

for i in range(10):
img_one_val = sess.run([img_one_queue])
fig = plt.figure()
plt.imshow(img_one_val[0])
plt.show()


coord.request_stop()
coord.join(threads)
sess.close()
summary_writer.close()

range_input_producer:生成0到limit-1的queue

def range_input_producer_demo(image_pair_path, summary_path):
tf.reset_default_graph()

image_one_path_list, image_two_path_list, label_list = load_data()
length_data = len(image_one_path_list)

image_one_path_list = tf.convert_to_tensor(image_one_path_list)
image_two_path_list = tf.convert_to_tensor(image_two_path_list)
label_list = tf.convert_to_tensor(label_list)

## 构造数据queue
train_input_queue = tf.train.range_input_producer(length_data, capacity=10 * batch_size)

## queue输出数据
range_index = train_input_queue.dequeue()
img_one_queue = get_image(tf.gather(image_one_path_list, range_index))
img_two_queue = get_image(tf.gather(image_two_path_list, range_index))
label_queue = range_index

## 批量从queu读取数据
batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch([img_one_queue, img_two_queue, label_queue],batch_size=batch_size,capacity = 10 + 10* batch_size,min_after_dequeue = 10,num_threads=16,shapes=[(image_width, image_height, image_channel),(image_width, image_height, image_channel),()])

sess = tf.Session()
sess.run(tf.initialize_all_variables())

summary_writer = tf.train.SummaryWriter(summary_path, graph_def=sess.graph)

## queue线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)

for i in range(10):
batch_img_one_val, batch_img_two_val, label = sess.run([batch_img_one, batch_img_two,batch_label])
for k in range(batch_size):
fig = plt.figure()
fig.add_subplot(1,2,1)
plt.imshow(batch_img_one_val[k])
fig.add_subplot(1,2,2)
plt.imshow(batch_img_two_val[k])
plt.show()


coord.request_stop()
coord.join(threads)
sess.close()
summary_writer.close()

input_producer:input_tensor里的行构成queue

def input_producer_demo(image_pair_path, summary_path):
tf.reset_default_graph()

image_one_path_list, image_two_path_list, label_list = load_data()
length_data = len(image_one_path_list)

image_one_path_list = tf.convert_to_tensor(image_one_path_list)

## 构造数据queue
train_input_queue = tf.train.input_producer(image_one_path_list, capacity=10 * batch_size)

## Expected string, got <tensorflow.python.ops.data_flow_ops.FIFOQueue object of type 'FIFOQueue' instead.
img_one_queue = get_image(train_input_queue.dequeue())

sess = tf.Session()
sess.run(tf.initialize_all_variables())

summary_writer = tf.train.SummaryWriter(summary_path, graph_def=sess.graph)

## queue线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)

for i in range(10):
batch_img_one_val = sess.run([img_one_queue])

# for k in range(batch_size):
print(len(batch_img_one_val))
fig = plt.figure()
plt.imshow(batch_img_one_val[0])
plt.show()


coord.request_stop()
coord.join(threads)
sess.close()
summary_writer.close()

supplementary

数据格式:
        /home/Alex/4000.jpg /home/Alex/4001.jpg 0
        /home/Alex/4000.jpg /home/Alex/4002.jpg 1

# 获取《图片一本地路径,图片二本地路径,标记》数据对
def load_data():
reader_handler = open(image_pair_path, 'r')

image_one_path_list = []
image_two_path_list = []
label_list = []

count = 0
for line in reader_handler:
count = count + 1
elems = line.split("\t")
if len(elems) < 3:
print("len(elems) < 3:" + line)
continue
image_one_path = elems[0].strip()
image_two_path = elems[1].strip()
label = int(elems[2].strip())

image_one_path_list.append(image_one_path)
image_two_path_list.append(image_two_path)
label_list.append(label)

return image_one_path_list, image_two_path_list, label_list


# 根据图片路径读取图片
def get_image(image_path):
"""Reads the jpg image from image_path.
Returns the image as a tf.float32 tensor
Args:
image_path: tf.string tensor
Reuturn:
the decoded jpeg image casted to float32
"""

content = tf.read_file(image_path)
tf_image = tf.image.decode_jpeg(content, channels=3)

return tf_image