1、简介
()
函数是 PyTorch 中用于创建数据加载器(data loader)的函数。数据加载器用于加载训练和测试数据集,并将数据划分为小批量进行处理。它们是数据处理流程中的关键组件,可以方便地进行数据批量处理、乱序加载和并行读取。
2、函数原型&参数含义
(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False)
的参数含义如下:
-
dataset
:数据集对象,可以是自定义的Dataset
类对象或 PyTorch 提供的预定义数据集,如。
-
batch_size
:批量大小,即每个小批量的样本数量。默认值是 1。 -
shuffle
:是否在每个 epoch 重新洗牌(随机打乱数据)。默认值是 False。 -
num_workers
:用于数据加载的子进程数量。默认值是 0,即在主进程中进行数据加载。 -
collate_fn
:用于将样本列表转换为小批量数据张量的函数。默认值是 None,表示使用默认的数据处理方式。 -
pin_memory
:是否将数据存储在 CUDA 固定内存中,以加速数据传输。默认值是 False。 -
drop_last
:当样本数不能被批量大小整除时,是否删除最后一个小于批量大小的小批量。默认值是 False,即保留部分小于批量大小的样本。