Pytorch for training1——read data/image

时间:2024-03-31 07:05:57

blog

torch.utils.data.Dataset

  1. create dataset with class torch.utils.data.Dataset automaticly
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # 根据索引获取样本
        return self.data[index]

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

# 创建数据集对象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# 根据索引获取样本
sample = dataset[2]
print(sample)

torchvision.datasets

  1. load data from classic dataset
import torch
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

2. load data from Imagefolder with transform

from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# transform.Compose是PyTorch中的一个类,用于将多个图像变换操作组合在一起。它的作用是将这些操作按照顺序依次应用于输入的图像数据。
trans = transforms.Compose([
    np.float32,
    transforms.ToTensor(),
    fixed_image_standardization
])

dataset = datasets.ImageFolder(data_dir, transform=trans)
loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    collate_fn=training.collate_pil
)

在这里插入图片描述

3. Introduction of Imagefolder

# 定义输入图像的数据加载器
dataset = datasets.ImageFolder(data_dir, transform=None)
# img_list_1=[img for (img,idx) in dataset.imgs]
# with open("img_list_1.pkl","wb") as file:
#     pickle.dump(img_list_1,file)
print(dataset)
print(len(dataset))
print(len(dataset.imgs))
print(len(dataset.classes))
print(dataset.classes[-1])
print(dataset.classes)
print(dataset.imgs)

在这里插入图片描述

\root
	\cls1
		\img1.png
		\img2.png
	\cls2
		\img1.png
		\img2.png
	\cls3
		\img1.png
		\img2.png

DataLoader

  1. to loader data from example of torch.utils.data.Dataset
import torch
from torchvision import datasets, transforms

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# 使用数据加载器迭代样本
for images, labels in train_loader:
    # 训练模型的代码
    ...

torchvision.transforms

from torchvision import transforms

# 定义图像预处理操作
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 缩放图像大小为 (256, 256)
    transforms.RandomCrop((224, 224)),  # 随机裁剪图像为 (224, 224)
    transforms.RandomHorizontalFlip(),  # 随机水平翻转图像
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化图像
])

# 对图像进行预处理
image = transform(image)

Image play with cv2,PIL.Image