blog
torch.utils.data.Dataset
- Create a custom dataset by subclassing torch.utils.data.Dataset (implement `__getitem__` and `__len__`).
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
    """Map-style dataset that wraps an already-indexable sequence.

    Args:
        data: Any object that supports ``len()`` and ``[]`` indexing
            (e.g. a list). It is stored as-is, not copied.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Return how many samples the dataset holds (the length of ``data``)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample at position ``index`` of the wrapped data."""
        sample = self.data[index]
        return sample
# Build the dataset object from a plain list of the integers 1..5.
data = list(range(1, 6))
dataset = MyDataset(data)
# Fetch one sample by index (goes through MyDataset.__getitem__); position 2 holds 3.
sample = dataset[2]
print(sample)
torchvision.datasets
1. Load a classic built-in dataset (e.g. MNIST)
import torch
from torchvision import datasets, transforms
# Preprocessing applied to every MNIST image: convert it to a tensor, then
# normalize the single channel as (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Load the MNIST training split first, then the test split, both rooted at
# ./data with download=True and the same transform.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
2. Load data from an ImageFolder directory with a transform
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np  # needed for np.float32 below; it was never imported (NameError)

# NOTE(review): `fixed_image_standardization` and `training` are not defined in
# this snippet. They look like facenet_pytorch helpers
# (`from facenet_pytorch import fixed_image_standardization, training`) — confirm.
# `data_dir`, `workers` and `batch_size` must also be defined by the reader.

# transforms.Compose chains several image transforms and applies them in order,
# one after another, to each input image.
trans = transforms.Compose([
    np.float32,  # presumably turns the loaded image into a float32 array — confirm
    transforms.ToTensor(),
    fixed_image_standardization,
])
dataset = datasets.ImageFolder(data_dir, transform=trans)

# NOTE(review): if `training.collate_pil` only gathers samples into Python lists,
# each batch here is a list of tensors rather than one stacked tensor. The default
# collate_fn would stack equally sized images — confirm which is intended.
loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    collate_fn=training.collate_pil,
)
3. Introduction to ImageFolder
# Build an ImageFolder *dataset* over data_dir with no transform applied.
# (This is a Dataset, not a DataLoader; wrap it in a DataLoader to get batches.)
dataset = datasets.ImageFolder(data_dir, transform=None)

# Optional: save the list of image paths to a pickle file.
# img_list_1 = [img for (img, idx) in dataset.imgs]
# with open("img_list_1.pkl", "wb") as file:
#     pickle.dump(img_list_1, file)

# Inspect the dataset.
print(dataset)  # summary repr of the dataset
# number of samples, number of entries in .imgs, number of classes
for count in (len(dataset), len(dataset.imgs), len(dataset.classes)):
    print(count)
print(dataset.classes[-1])  # the last class name
print(dataset.classes)      # all class names
print(dataset.imgs)         # the (image_path, class_index) pairs
Expected directory layout for ImageFolder (one sub-folder per class):

    root/
        cls1/
            img1.png
            img2.png
        cls2/
            img1.png
            img2.png
        cls3/
            img1.png
            img2.png
DataLoader
- Load data in batches from a torch.utils.data.Dataset instance
import torch
from torchvision import datasets, transforms
# Create data loaders for the MNIST train_dataset / test_dataset built earlier.
# Only the training split is shuffled; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# Iterating a loader with num_workers > 0 starts worker processes. Under the
# "spawn" start method (per the PyTorch DataLoader docs, the default on Windows
# and macOS), those workers re-import this module. The iteration must therefore
# sit behind a __main__ guard, or worker startup fails.
if __name__ == "__main__":
    # Iterate over mini-batches of (images, labels).
    for images, labels in train_loader:
        # Model training code goes here.
        ...
torchvision.transforms
from torchvision import transforms
# Image preprocessing pipeline; the steps run in the order listed.
transform = transforms.Compose([
    transforms.Resize((256, 256)),       # resize to 256 x 256
    transforms.RandomCrop((224, 224)),   # take a random 224 x 224 crop
    transforms.RandomHorizontalFlip(),   # randomly flip left-right
    transforms.ToTensor(),               # convert the image to a tensor
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # per-channel (x - 0.5) / 0.5
])

# Apply the pipeline to one image.
# NOTE(review): `image` is not defined in this snippet. It is presumably an image
# loaded by the reader; the 3-value mean/std imply a 3-channel (RGB) input — confirm.
image = transform(image)