【pytorch】土堆pytorch教程学习(六)DataLoader 的使用
DataLoader将数据集(dataset)和采样器(sampler)组合在一起,并在给定数据集上提供迭代。
DataLoader支持 map 式和 iterable 式的数据集,可进行单进程或多进程加载、自定义加载顺序和可选的自动批处理和内存固定。
先看下实例化一个 DataLoader 所需的参数:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
multiprocessing_context=None, generator=None, *,
prefetch_factor=None, persistent_workers=False, pin_memory_device='')
只需关注几个常用的参数即可,剩下的可以慢慢了解:
dataset:要从中加载数据的数据集batch_size:每批要加载的样本大小,默认为1shuffle:设置为True将在每个 epoch(新阶段)重组数据(默认值:False)num_workers:用于数据加载的子进程数。0表示将在主进程中加载数据。(默认值: 0)drop_last:如果数据集大小不能被batch_size整除,设置为True可删除最后一个不完整的批处理。sampler:定义从数据集中提取样本的策略。可以是实现了__len__的任何 Iterable。如果指定了,则shuffle必须为False。
一般 PyTorch 中深度学习训练的流程如下:
- 创建Dateset
- Dataset传递给DataLoader
- DataLoader迭代产生训练数据提供给模型
# 创建 Dateset
test_set = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())
# Dataset 传递给 DataLoader
dataloader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# DataLoader 迭代产生训练数据提供给模型
for i in range(epoch):
for index,(imgs,targets) in enumerate(dataloader):
pass
Dataset 负责建立索引到样本的映射,DataLoader 负责以特定的方式从数据集中迭代的产生一个个 batch 的样本集合。在 enumerate 过程中实际上是 dataloader 按照其参数 sampler 规定的策略调用了其 dataset 的 __getitem__方法。