torchvision中的数据集使用

时间:2024-04-07 10:31:53

torchvision中的数据集使用

在这里插入图片描述
在这里插入图片描述

使用和下载CIFAR10数据集

在这里插入图片描述

输出测试集中的第一个元素(输出img信息和target)

在这里插入图片描述

查看分类classes

打断点–>右键Debug–>找到classes
在这里插入图片描述

代码

import torchvision

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

print(test_set[0])
print(test_set.classes)

img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])  # 输出target对应的classes
img.show()  # 输出图片

在这里插入图片描述

将图片转换成tensor数据类型

import torchvision

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

print(test_set[0])  # 测试第一张图片

在这里插入图片描述

创建日志文件

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

# print(test_set[0])

writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()
运行后在Terminal中输入(先进入pytorch环境中):
tensorboard --logdir="learn_pytorch/p10"  # 注意路径的选择,"p10"会报错

在这里插入图片描述
在这里插入图片描述