Pytorch 用法详细介绍

时间:2025-03-30 20:15:20
  • dataset (Dataset) – 加载的数据集

  • batch_size (int, optional) – 每一次处理加载多少数据

  • shuffle (bool, optional) – True 表示每次 epoch 都要重新打乱数据,默认 False

  • sampler (Sampler or Iterable, optional) – 定义采样的策略。如果定义了此参数,那么 shuffle 参数必须为 False

  • batch_sampler (Sampler or Iterable, optional) – 同 sample 一样,但每次返回数据的索引。与 batch_sizeshufflesampledrop_last 参数互斥

  • num_workers (int, optional) – 指定用于数据加载的子进程数,可以加快数据加载速度。默认0,表示用主进程加载

  • collate_fn (Callable, optional) – 批处理函数,用于将多个样本合并成一个批次,例如将多个张量拼接在一起,构建 mini-batch。当使用 map-style 数据集进行批量加载时使用。

  • pin_memory (bool, optional) – True 表示在返回张量之前将张量复制到 CUDA 固定的内存中,加快 GPU 传输速度

  • drop_last (bool, optional) – True 表示可删除最后一个不完整的批次。默认 False,如果数据集的大小不能被批次大小整除,则最后一个批次会更小。

  • timeout (numeric, optional) – 非负数,worker 收集批次数据的超时时间,默认0

  • worker_init_fn (Callable, optional) – 如果非None,则在种子设定之后和数据加载之前,将以worker id([0,num_workers-1]中的int)作为输入对每个 worker 子进程调用此函数。(默认值:None)

  • multiprocessing_context (str or , optional) – 如果为None,则将使用操作系统的默认多处理上下文。(默认值:None)

  • generator (, optional) – 如果非None,则RandomSampler 将使用此RNG来生成随机索引,并进行多进程处理以为 workers 生成 base_seed。(默认值:None)

  • prefetch_factor (int, optional, keyword-only arg) – 每个 worker 预先装载的批次数。2 表示在所有工作线程中总共预取2*num_workers批次。(默认值取决于为num_workers设置的值。如果num_workers=0的值,则默认为None。否则,如果num_workers>0的值,默认为2)

  • persistent_workers (bool, optional) – True 表示不会在数据集使用一次后关闭工作进程。这允许保持 worker 实例处于活动状态。(默认值:False)

  • pin_memory_device (str, optional) – 如果 pin_memory 为 True,该参数表示 pin_memory 所指向的设备