This article describes concrete uses of next_batch in TensorFlow, shared here for reference.
Several different next_batch implementations are shown below; this post only annotates the code snippets so they can be looked up later. The first method:
def next_batch(self, batch_size, fake_data=False):
    """Return the next `batch_size` examples from this data set."""
    if fake_data:
        fake_image = [1] * 784
        if self.one_hot:
            fake_label = [1] + [0] * 9
        else:
            fake_label = 0
        return [fake_image for _ in xrange(batch_size)], [
            fake_label for _ in xrange(batch_size)
        ]
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:  # does the index run past the total number of examples? if so, start a new pass
        # Finished epoch
        self._epochs_completed += 1
        # Shuffle the data
        perm = numpy.arange(self._num_examples)  # arange creates the index array [0, 1, ..., _num_examples - 1]
        numpy.random.shuffle(perm)  # shuffle the indices in place
        self._images = self._images[perm]
        self._labels = self._labels[perm]
        # Start next epoch
        start = 0
        self._index_in_epoch = batch_size
        assert batch_size <= self._num_examples
    end = self._index_in_epoch
    return self._images[start:end], self._labels[start:end]
This code is taken from the mnist.py file. Starting from the line start = self._index_in_epoch: _index_in_epoch - 1 is the index of the last image of the previous batch, so the current batch starts at index _index_in_epoch and its slice ends at _index_in_epoch + batch_size. If the updated _index_in_epoch is larger than the number of images in the corpus, this batch would run past the data, which means one full pass over the corpus has been completed; the images are therefore reshuffled and a new round of batches starts from the beginning of the corpus.
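For context, here is a minimal sketch of how this next_batch is typically driven from a TensorFlow 1.x training loop; the softmax model and the hyperparameters below are only illustrative and are not part of the original mnist.py:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# a deliberately simple softmax classifier, just to have something to train
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1000):
        # each call returns the next 100 images/labels; the DataSet shuffles
        # and starts a new epoch on its own once the corpus is exhausted
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})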
The second method:
import numpy as np

def ptb_iterator(raw_data, batch_size, num_steps):
    """Iterate on the raw PTB data.
    This generates batch_size pointers into the raw PTB data, and allows
    minibatch iteration along these pointers.
    Args:
        raw_data: one of the raw data outputs from ptb_raw_data.
        batch_size: int, the batch size.
        num_steps: int, the number of unrolls.
    Yields:
        Pairs of the batched data, each a matrix of shape [batch_size, num_steps].
        The second element of the tuple is the same data time-shifted to the
        right by one.
    Raises:
        ValueError: if batch_size or num_steps are too high.
    """
    raw_data = np.array(raw_data, dtype=np.int32)
    data_len = len(raw_data)
    batch_len = data_len // batch_size  # number of words in each of the batch_size rows
    data = np.zeros([batch_size, batch_len], dtype=np.int32)
    for i in range(batch_size):  # split the corpus into batch_size contiguous rows
        data[i] = raw_data[batch_len * i:batch_len * (i + 1)]
    epoch_size = (batch_len - 1) // num_steps  # how many (x, y) minibatches one pass over the data yields
    # epoch_size = ((len(data) // model.batch_size) - 1) // model.num_steps  # // is integer division
    if epoch_size == 0:
        raise ValueError("epoch_size == 0, decrease batch_size or num_steps")
    for i in range(epoch_size):
        x = data[:, i * num_steps:(i + 1) * num_steps]
        y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
        yield (x, y)
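A quick way to see what ptb_iterator yields is to feed it a toy word-id sequence; the values below are illustrative only:

raw_data = list(range(20))  # a toy corpus of 20 word ids

for step, (x, y) in enumerate(ptb_iterator(raw_data, batch_size=2, num_steps=3)):
    # x and y both have shape (2, 3); y is x shifted one word to the right
    print("step", step)
    print(x)
    print(y)

With 20 words and batch_size = 2, batch_len is 10 and epoch_size is (10 - 1) // 3 = 3, so the generator stops after three (x, y) pairs.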
The third method:
def next(self, batch_size):
    """Return a batch of data. When dataset end is reached, start over."""
    if self.batch_id == len(self.data):
        self.batch_id = 0
    batch_data = (self.data[self.batch_id:min(self.batch_id +
                                              batch_size, len(self.data))])
    batch_labels = (self.labels[self.batch_id:min(self.batch_id +
                                                  batch_size, len(self.data))])
    batch_seqlen = (self.seqlen[self.batch_id:min(self.batch_id +
                                                  batch_size, len(self.data))])
    self.batch_id = min(self.batch_id + batch_size, len(self.data))
    return batch_data, batch_labels, batch_seqlen
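This next method assumes it is defined on a dataset object that carries data, labels and seqlen lists plus a batch_id cursor. A minimal toy wrapper that satisfies those assumptions (the class and variable names here are illustrative, not from the original source) could look like this:

class ToySequenceData(object):
    # minimal container exposing the attributes that next() reads and updates
    def __init__(self, data, labels, seqlen):
        self.data = data
        self.labels = labels
        self.seqlen = seqlen
        self.batch_id = 0

    def next(self, batch_size):
        """Return a batch of data. When dataset end is reached, start over."""
        if self.batch_id == len(self.data):
            self.batch_id = 0
        end = min(self.batch_id + batch_size, len(self.data))
        batch_data = self.data[self.batch_id:end]
        batch_labels = self.labels[self.batch_id:end]
        batch_seqlen = self.seqlen[self.batch_id:end]
        self.batch_id = end
        return batch_data, batch_labels, batch_seqlen

ds = ToySequenceData(data=list(range(5)), labels=list(range(5)), seqlen=[1] * 5)
for _ in range(4):
    # after the final partial batch, batch_id wraps to 0 and iteration starts over (no shuffling)
    print(ds.next(2))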
The fourth method:
import numpy as np

def batch_iter(sourceData, batch_size, num_epochs, shuffle=True):
    data = np.array(sourceData)  # store sourceData as a numpy array
    data_size = len(data)
    # ceiling division, so there is no empty trailing batch when data_size divides evenly
    num_batches_per_epoch = (data_size - 1) // batch_size + 1
    for epoch in range(num_epochs):
        # Shuffle the data at each epoch
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]  # index the array copy; sourceData may be a plain list
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_index:end_index]
This is standard generator/iterator usage; for the details of how such functions are consumed, study how Python iterators work.
Another thing to note is that the first three methods only cover one pass over the corpus (the caller has to loop over epochs itself), whereas the last method iterates over the whole corpus num_epochs times by itself.
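A quick check of that difference with batch_iter on a toy array (the values are illustrative only, and np is the numpy module imported above):

toy = np.arange(10)
batches = list(batch_iter(toy, batch_size=4, num_epochs=3, shuffle=False))
print(len(batches))                # 9: three batches per epoch, repeated for three epochs
print([len(b) for b in batches])   # [4, 4, 2, 4, 4, 2, 4, 4, 2]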
That is all for this article. I hope it helps with your study, and I also hope you will keep supporting 服务器之家.
Original article: http://blog.csdn.net/appleml/article/details/57413615