【torch】数据加载器DataLoader

时间:2025-03-30 20:14:30

1、简介

() 函数是 PyTorch 中用于创建数据加载器(data loader)的函数。数据加载器用于加载训练和测试数据集,并将数据划分为小批量进行处理。它们是数据处理流程中的关键组件,可以方便地进行数据批量处理、乱序加载和并行读取。

2、函数原型&参数含义

(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False) 的参数含义如下:

  • dataset:数据集对象,可以是自定义的 Dataset 类对象或 PyTorch 提供的预定义数据集,如 
  • batch_size:批量大小,即每个小批量的样本数量。默认值是 1。
  • shuffle:是否在每个 epoch 重新洗牌(随机打乱数据)。默认值是 False。
  • num_workers:用于数据加载的子进程数量。默认值是 0,即在主进程中进行数据加载。
  • collate_fn:用于将样本列表转换为小批量数据张量的函数。默认值是 None,表示使用默认的数据处理方式。
  • pin_memory:是否将数据存储在 CUDA 固定内存中,以加速数据传输。默认值是 False。
  • drop_last:当样本数不能被批量大小整除时,是否删除最后一个小于批量大小的小批量。默认值是 False,即保留部分小于批量大小的样本。

相关文章