CNN实战基础——读取图片数据
实现结果如下图:
取数据
1.导包
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
2.下载数据
trans = transforms.ToTensor()
train_dataset = datasets.MNIST(
root='./drive/MyDrive/DataSet', train=True, transform=trans, download=True)
test_dataset = datasets.MNIST(
root='./drive/MyDrive/DataSet', train=False, transform=trans, download=True)
root='xx'
:表示 数据集所在路径,数据集不用解压train=True or False
:表明是 训练集或测试集transform=trans
:把读取的Img类型图片转为Tensordown=True
:表示 若数据集在root路径下,则直接加载;若不在root里,则下载(外网很慢,可以提前下载好放进去)
下图为下载之后,自动创建了raw文件夹,数据集在raw里
3.加载数据
train_dl = DataLoader(train_dataset, batch_size=32)
param1:指明加载什么数据集
param2:一批有32张图
4.打印数据
我们往往想看看第一批(32张)图,但不能通过train_dl[0]
等下标方式访问,下面有两种方式都可以:
4.1 next + enumerate
examples = enumerate(train_dl) # 返回数字下标+迭代器(可用next访问)
next(examples)
4.2 next + iter
examples = iter(train_dl) # 返回迭代器,无下标
imgs, labels = next(examples)
结果可视化
第一版
import matplotlib.pyplot as plt
fig = plt.figure() # 创建一个窗口
for i in range(6):
plt.subplot(2,3,i+1)
plt.imshow(imgs[i][0])
plt.title('label:{}'.format(labels[i]))# 格式化输出,否则是tensor数据
plt.imshow(imgs[i][0])
这里指:imgs是4维数据,后面2维是28*28图片大小,不用管,关键是用前两维定位到图片
结果:
发现布局有问题,故增加plt.tight_layout()
语句,结果明显分开
发现x,y轴的数字可以不要,添加plt.xticks([]), plt.yticks([])
,结果:
想要灰度图,设置一个参数即可plt.imshow(imgs[i][0],cmap='gray')
完整代码:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
trans = transforms.ToTensor()
train_dataset = datasets.MNIST(
root='./drive/MyDrive/DataSet', train=True, transform=trans, download=True)
test_dataset = datasets.MNIST(
root='./drive/MyDrive/DataSet', train=False, transform=trans, download=True)
train_dl = DataLoader(train_dataset, batch_size=32)
# train_dl里有很多batch,每个batch里有batch_size张图片
#examples = enumerate(train_dl)
#next(examples)
# 用iter会少序号,和enumerate有一点不同,其它差不多可以打印出来看
examples = iter(train_dl)
imgs, labels = next(examples)
# print(len(imgs), len(labels))
#-------------------------------数据显示--------------------------------------------
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(imgs[i][0],cmap='gray') # 4维数据,后面2维是28*28图片大小,不用管,关键是用前两维定位到图片
plt.title('label:{}'.format(labels[i]))# 格式化输出,否则是tensor数据
plt.xticks([]) # 清空x,y轴的数字
plt.yticks([])