【pytorch】土堆pytorch教程学习(六)DataLoader 的使用

hzyuan / 2023-05-04 / 原文

DataLoader 将数据集(dataset)和采样器(sampler)组合在一起,并在给定数据集上提供迭代。

DataLoader 支持 map 式和 iterable 式的数据集,可进行单进程或多进程加载、自定义加载顺序和可选的自动批处理和内存固定。

先看下实例化一个 DataLoader 所需的参数:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None,
                            batch_sampler=None, num_workers=0, collate_fn=None, 
                            pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
                            multiprocessing_context=None, generator=None, *, 
                            prefetch_factor=None, persistent_workers=False, pin_memory_device='')

只需关注几个常用的参数即可,剩下的可以慢慢了解:

  • dataset:要从中加载数据的数据集
  • batch_size:每批要加载的样本大小,默认为1
  • shuffle:设置为 True 将在每个 epoch(新阶段)重组数据(默认值:False)
  • num_workers:用于数据加载的子进程数。0表示将在主进程中加载数据。(默认值: 0)
  • drop_last:如果数据集大小不能被 batch_size 整除,设置为 True 可删除最后一个不完整的批处理。
  • sampler:定义从数据集中提取样本的策略。可以是实现了 __len__ 的任何 Iterable。如果指定了,则 shuffle 必须为 False

一般 PyTorch 中深度学习训练的流程如下:

  1. 创建Dateset
  2. Dataset传递给DataLoader
  3. DataLoader迭代产生训练数据提供给模型
# 创建 Dateset
test_set = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())
# Dataset 传递给 DataLoader
dataloader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# DataLoader 迭代产生训练数据提供给模型
for i in range(epoch):
    for index,(imgs,targets) in enumerate(dataloader):
        pass

Dataset 负责建立索引到样本的映射,DataLoader 负责以特定的方式从数据集中迭代的产生一个个 batch 的样本集合。在 enumerate 过程中实际上是 dataloader 按照其参数 sampler 规定的策略调用了其 dataset 的 __getitem__方法。