【pytorch】土堆pytorch教程学习(五)torchvision 中的数据集的使用
torchvision 中的数据集使用
在torchvision.datasets
模块中提供了许多内置的数据集。
内置的数据集有 CIFAR10、MNIST、COCO等,更多可进入 pytorch 官网查看。
所有内置的数据集都继承了 torch.utils.data.Dataset
类,并且实现了 __getitem__
和 __len__
。
所有的数据集几乎都有相似的API。下面以 CIFAR10
数据集的使用为例来认识下内置数据集的用法。
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
'''
dataset = torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
Args:
root(string):数据集存放的根目录。
train(bool):如果True则从训练集创建数据集,False则从测试集创建数据集。
transform(callable):需要对图像进行的转换操作
target_transforms(callable):需要对 target 进行的转换操作
download(bool):True则从网络下载数据集到根目录。如果数据集已经存在,则不再下载。
'''
# 创建训练集
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transform, download=True)
# 创建测试集
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transform, download=True)
img, target = test_set[0] # 取出图像和target
print(img, test_set.classes[target])
# 在 tensorboard 里打开十张图像
writer = SummaryWriter('logs')
for i in range(10):
img, target = test_set[i]
writer.add_image('test_set', img, i)
writer.close()
内置数据集很方便地供我们下载使用。根据源码或者官方文档可以了解到创建数据集所需传入的参数,然后需要关注__getitem__
方法返回的结果是什么。
自定义数据集
自己定义的数据集可以参照内置数据集,即继承 torch.utils.data.Dataset
类,并且重写 __getitem__
和 __len__
。
数据存放在 dataset/train
里,分为两个目录 ants
和 bees
,也分别是数据的标签,如下图所示:
from PIL import Image
from torch.utils.data import Dataset
import os
class MyDataSet(Dataset):
# 在__init__里加载
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 根目录
self.label_dir = label_dir # 标签目录
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path) # 图片路径列表
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.path, img_name)
img = Image.open(img_item_path) # 获取数据
label = self.label_dir # 获取label
return img, label
def __len__(self):
return len(self.img_path) # 获取数据集长度
# test
root_dir = 'dataset/train'
ants_label_dir = 'ants'
bees_label_dir = 'bees'
# 生成两个数据集
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset # 拼接两个数据集
img1, label1 = ants_dataset[0]
img1.show()
print('label1:', label1)
img2, label2 = train_dataset[130]
img2.show()
print('label2:', label2)