先放关键代码:
1
2
|
i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs = 1 , shuffle = False ).dequeue()
inputs = tf. slice (array, [i * BATCH_SIZE], [BATCH_SIZE])
|
原理解析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2...
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。
下面是完整的演示代码。
数据文件test.txt内容:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
main.py内容:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
|
import tensorflow as tf
import codecs
BATCH_SIZE = 6
NUM_EXPOCHES = 5
def input_producer():
array = codecs. open ( "test.txt" ).readlines()
array = map ( lambda line: line.strip(), array)
i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs = 1 , shuffle = False ).dequeue()
inputs = tf. slice (array, [i * BATCH_SIZE], [BATCH_SIZE])
return inputs
class Inputs( object ):
def __init__( self ):
self .inputs = input_producer()
def main( * args, * * kwargs):
inputs = Inputs()
init = tf.group(tf.initialize_all_variables(),
tf.initialize_local_variables())
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess = sess, coord = coord)
sess.run(init)
try :
index = 0
while not coord.should_stop() and index< 10 :
datalines = sess.run(inputs.inputs)
index + = 1
print ( "step: %d, batch data: %s" % (index, str (datalines)))
except tf.errors.OutOfRangeError:
print ( "Done traing:-------Epoch limit reached" )
except KeyboardInterrupt:
print ( "keyboard interrput detected, stop training" )
finally :
coord.request_stop()
coord.join(threads)
sess.close()
del sess
if __name__ = = "__main__" :
main()
|
输出:
1
2
3
4
5
6
|
step: 1 , batch data: [ '1' '2' '3' '4' '5' '6' ]
step: 2 , batch data: [ '7' '8' '9' '10' '11' '12' ]
step: 3 , batch data: [ '13' '14' '15' '16' '17' '18' ]
step: 4 , batch data: [ '19' '20' '21' '22' '23' '24' ]
step: 5 , batch data: [ '25' '26' '27' '28' '29' '30' ]
Done traing: - - - - - - - Epoch limit reached
|
如果range_input_producer去掉参数num_epochs=1,则输出:
1
2
3
4
5
6
7
8
9
10
|
step: 1 , batch data: [ '1' '2' '3' '4' '5' '6' ]
step: 2 , batch data: [ '7' '8' '9' '10' '11' '12' ]
step: 3 , batch data: [ '13' '14' '15' '16' '17' '18' ]
step: 4 , batch data: [ '19' '20' '21' '22' '23' '24' ]
step: 5 , batch data: [ '25' '26' '27' '28' '29' '30' ]
step: 6 , batch data: [ '1' '2' '3' '4' '5' '6' ]
step: 7 , batch data: [ '7' '8' '9' '10' '11' '12' ]
step: 8 , batch data: [ '13' '14' '15' '16' '17' '18' ]
step: 9 , batch data: [ '19' '20' '21' '22' '23' '24' ]
step: 10 , batch data: [ '25' '26' '27' '28' '29' '30' ]
|
有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:
1
2
|
InvalidArgumentError (see above for traceback): Expected size[ 0 ] in [ 0 , 5 ], but got 6
[[Node: Slice = Slice [Index = DT_INT32, T = DT_STRING, _device = "/job:localhost/replica:0/task:0/cpu:0" ]( Slice / input , Slice / begin / _5, Slice / size)]]
|
错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。
以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/lyg5623/article/details/69387917