PyTorch使用细节

时间:2024-07-15 10:44:23

root本地文件夹里有,则从本地读;没有的话,如指定了ownload=True,则从远程下载;

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

Dataset类:通过index,拿到1条数据;

        数据可以都在磁盘上,用到哪条,就加载哪条;

        自定义一个类,需要继承Dataset类,并重写__init__、__len__、__getitem__

DataLoader类:batching, shuffle(sampling策略), multiprocess加载,pin memory,...

ToTensor(): 把PIL格式的Image,转成Tensor;

Lambda: 把int的y,转成10维度的1-hot向量;