一、train_on_batch
1
|
model.train_on_batch(batchX, batchY)
|
train_on_batch函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型,大部分情况下我们不需要这么精细,99%情况下使用fit_generator训练方式即可,下面会介绍。
二、fit
1
|
model.fit(x_train, y_train, batch_size = 32 , epochs = 10 )
|
fit的方式是一次把训练数据全部加载到内存中,然后每次批处理batch_size个数据来更新模型参数,epochs就不用多介绍了。这种训练方式只适合训练数据量比较小的情况下使用。
三、fit_generator
利用Python的生成器,逐个生成数据的batch并进行训练,不占用大量内存,同时生成器与模型将并行执行以提高效率。例如,该函数允许我们在CPU上进行实时的数据提升,同时在GPU上进行模型训练
接口如下:
1
|
fit_generator( self , generator, steps_per_epoch, epochs = 1 , verbose = 1 , callbacks = None , validation_data = None , validation_steps = None , class_weight = None , max_q_size = 10 , workers = 1 , pickle_safe = False , initial_epoch = 0 )
|
generator
:生成器函数
steps_per_epoch
:整数,当生成器返回steps_per_epoch次数据时,计一个epoch结束,执行下一个epoch。也就是一个epoch下执行多少次batch_size。
epochs
:整数,控制数据迭代的轮数,到了就结束训练。
callbacks=None, list,list中的元素为keras.callbacks.Callback对象,在训练过程中会调用list中的回调函数
举例:
1
2
3
4
5
6
7
8
9
10
11
|
def generate_arrays_from_file(path):
while True :
with open (path) as f:
for line in f:
# create numpy arrays of input data
# and labels, from each line in the file
x1, x2, y = process_line(line)
yield ({ 'input_1' : x1, 'input_2' : x2}, { 'output' : y})
model.fit_generator(generate_arrays_from_file( './my_folder' ),
steps_per_epoch = 10000 , epochs = 10 )
|
补充:keras.fit_generator()属性及取值
如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
|
fit_generator( self , generator,
steps_per_epoch = None ,
epochs = 1 ,
verbose = 1 ,
callbacks = None ,
validation_data = None ,
validation_steps = None ,
class_weight = None ,
max_queue_size = 10 ,
workers = 1 ,
use_multiprocessing = False ,
shuffle = True ,
initial_epoch = 0 )
|
通过Python generator产生一批批的数据用于训练模型。generator可以和模型并行运行,例如,可以使用CPU生成批数据同时在GPU上训练模型。
参数:
generator
:一个generator或Sequence实例,为了避免在使用multiprocessing时直接复制数据。
steps_per_epoch
:从generator产生的步骤的总数(样本批次总数)。通常情况下,应该等于数据集的样本数量除以批量的大小。
epochs
:整数,在数据集上迭代的总数。
works
:在使用基于进程的线程时,最多需要启动的进程数量。
use_multiprocessing
:布尔值。当为True时,使用基于基于过程的线程。
例如:
1
2
3
4
5
6
|
datagen = ImageDataGenator(...)
model.fit_generator(datagen.flow(x_train, y_train,
batch_size = batch_size),
epochs = epochs,
validation_data = (x_test, y_test),
workers = 4 )
|
以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://www.cnblogs.com/gczr/p/12380887.html