root本地文件夹里有,则从本地读;没有的话,如指定了ownload=True,则从远程下载;
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Dataset类:通过index,拿到1条数据;
数据可以都在磁盘上,用到哪条,就加载哪条;
自定义一个类,需要继承Dataset类,并重写__init__、__len__、__getitem__
DataLoader类:batching, shuffle(sampling策略), multiprocess加载,pin memory,...
ToTensor(): 把PIL格式的Image,转成Tensor;
Lambda: 把int的y,转成10维度的1-hot向量;