文章目录
第一步:构建路径与种类的映射关系
import os
from torch.utils.data import Dataset
class Pokeman(Dataset):
    """Pokemon image dataset, step 1: build the class-name -> label mapping.

    Every sub-directory of ``root`` is one class; labels are assigned in
    sorted directory-name order starting at 0.
    """

    def __init__(self, root, resize, model):
        """Scan ``root`` and fill ``self.name2label``.

        Args:
            root: dataset root directory; each sub-directory is one class.
            resize: target image side length (stored; unused at this step).
            model: split name; not used yet at this step.
        """
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)
        # Plain files directly under root are not classes, so keep directories only.
        class_dirs = [entry for entry in sorted(os.listdir(root))
                      if os.path.isdir(os.path.join(root, entry))]
        self.name2label = {entry: idx for idx, entry in enumerate(class_dirs)}
        print(self.name2label)

    def __len__(self):
        """Placeholder, filled in by a later step (returns None)."""
        return None

    def __getitem__(self, idx):
        """Placeholder, filled in by a later step (returns None)."""
        return None
def main(root=r'D:\pythonProject\pythonProject39\pokeman'):
    """Build the training split to inspect the printed class -> label mapping.

    Args:
        root: dataset root directory. The default is a raw string: the old
            plain literal relied on the invalid escape sequence '\\p', which
            Python warns about (the resulting string value is the same).
    """
    Pokeman(root, 224, 'train')


if __name__ == '__main__':
    main()
第二步:载入所有的宝可梦图像
import os,glob
from torch.utils.data import Dataset
class Pokeman(Dataset):
    """Pokemon image dataset, step 2: label mapping plus image discovery.

    Sub-directories of ``root`` are the classes (labels in sorted-name order);
    ``load_csv`` currently only collects and prints the image paths.
    """

    def __init__(self, root, resize, model):
        """Build ``self.name2label`` from ``root`` and list the images.

        Args:
            root: dataset root directory; each sub-directory is one class.
            resize: target image side length (stored; unused at this step).
            model: split name; not used yet at this step.
        """
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)
        # Plain files directly under root are not classes, so keep directories only.
        class_dirs = [entry for entry in sorted(os.listdir(root))
                      if os.path.isdir(os.path.join(root, entry))]
        self.name2label = {entry: idx for idx, entry in enumerate(class_dirs)}
        print(self.name2label)
        self.load_csv('images.csv')

    def load_csv(self, filename):
        """Collect every .png/.jpg/.jpeg path per class directory and print them.

        ``filename`` is accepted for the later CSV step but not used yet.
        """
        patterns = ('*.png', '*.jpg', '*.jpeg')
        found = [path
                 for cls in self.name2label
                 for pattern in patterns
                 for path in glob.glob(os.path.join(self.root, cls, pattern))]
        # e.g. 1167, 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        print(len(found), found)

    def __len__(self):
        """Placeholder, filled in by a later step (returns None)."""
        return None

    def __getitem__(self, idx):
        """Placeholder, filled in by a later step (returns None)."""
        return None
def main(root=r'D:\pythonProject\pythonProject39\pokeman'):
    """Build the training split, which prints the label mapping and image list.

    Args:
        root: dataset root directory. The default is a raw string: the old
            plain literal relied on the invalid escape sequence '\\p', which
            Python warns about (the resulting string value is the same).
    """
    Pokeman(root, 224, 'train')


if __name__ == '__main__':
    main()
第三步：打散顺序，并通过路径名提取类别标签，生成映射文件（images.csv）
import csv
import os,glob
import random
from torch.utils.data import Dataset
class Pokeman(Dataset):
    """Pokemon dataset, step 3: build and cache a shuffled ``path,label`` index.

    Sub-directories of ``root`` are the classes (labels in sorted-name order).
    On first use every .png/.jpg/.jpeg in a class directory is collected,
    shuffled and written to ``root/<filename>``; later runs read that CSV back.
    """

    def __init__(self, root, resize, model):
        """Build ``self.name2label`` and load ``self.images`` / ``self.labels``.

        Args:
            root: dataset root directory; each sub-directory is one class.
            resize: target image side length (stored; unused at this step).
            model: split name; not used yet at this step.
        """
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)
        # Plain files directly under root (e.g. the CSV itself) are not classes.
        class_dirs = [entry for entry in sorted(os.listdir(root))
                      if os.path.isdir(os.path.join(root, entry))]
        self.name2label = {entry: idx for idx, entry in enumerate(class_dirs)}
        print(self.name2label)
        self.images, self.labels = self.load_csv('images.csv')

    def load_csv(self, filename):
        """Return ``(image_paths, int_labels)`` read from ``root/filename``.

        The CSV is created first (shuffled, one ``path,label`` row per image)
        if it does not exist yet.
        """
        index_path = os.path.join(self.root, filename)
        if not os.path.exists(index_path):
            found = []
            for cls in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    found.extend(glob.glob(os.path.join(self.root, cls, pattern)))
            # e.g. 1167, 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(found), found)
            random.shuffle(found)
            with open(index_path, mode='w', newline='') as out:
                # The class name is the second-to-last path component.
                csv.writer(out).writerows(
                    [path, self.name2label[path.split(os.sep)[-2]]]
                    for path in found
                )
            print('writen into csv file', filename)

        with open(index_path) as src:
            rows = [(path, int(label)) for path, label in csv.reader(src)]
        images = [path for path, _ in rows]
        labels = [label for _, label in rows]
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Placeholder, filled in by a later step (returns None)."""
        return None

    def __getitem__(self, idx):
        """Placeholder, filled in by a later step (returns None)."""
        return None
def main(root=r'D:\pythonProject\pythonProject39\pokeman'):
    """Build the training split, creating ``images.csv`` on the first run.

    Args:
        root: dataset root directory. The default is a raw string: the old
            plain literal relied on the invalid escape sequence '\\p', which
            Python warns about (the resulting string value is the same).
    """
    Pokeman(root, 224, 'train')


if __name__ == '__main__':
    main()
第四步:完善选取、获取图片信息功能并可视化
import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
class Pokeman(Dataset):
    """Pokemon image dataset backed by a cached ``images.csv`` index.

    Expects one sub-directory per class under ``root`` containing
    .png/.jpg/.jpeg files; labels follow the sorted sub-directory names.
    The shuffled ``path,label`` index is written to ``root/images.csv`` on
    first use and re-read afterwards, so the 60/20/20 split is stable.
    """

    def __init__(self, root, resize, model):
        """Index the images under ``root`` and keep only the requested split.

        Args:
            root: dataset root; each sub-directory is one class.
            resize: output image side length in pixels.
            model: which split to expose -- 'train' (first 60% of the shuffled
                index), 'val' (next 20%); any other value selects the last
                20% (test).
        """
        super(Pokeman, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}
        print(root)
        for name in sorted(os.listdir(root)):
            # Only sub-directories are classes; stray files (e.g. images.csv) are skipped.
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)
        print(self.name2label)

        self.images, self.labels = self.load_csv('images.csv')
        # Compute the split boundaries once, from the full index, so images
        # and labels are sliced with exactly the same bounds.
        n = len(self.images)
        train_end, val_end = int(0.6 * n), int(0.8 * n)
        if model == 'train':
            start, stop = 0, train_end
        elif model == 'val':
            start, stop = train_end, val_end
        else:
            start, stop = val_end, n
        self.images = self.images[start:stop]
        self.labels = self.labels[start:stop]

    def load_csv(self, filename):
        """Return ``(image_paths, int_labels)`` read from ``root/filename``.

        If the CSV does not exist yet it is built first: every class directory
        is globbed for .png/.jpg/.jpeg files, the paths are shuffled, and one
        ``path,label`` row is written per image.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # e.g. 1167, 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images), images)
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The class name is the image's parent directory.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
            print('writen into csv file', filename)

        images, labels = [], []
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``.

        The image is loaded as RGB, resized to ``resize x resize`` and
        converted to a tensor with ``transforms.ToTensor``.
        """
        # e.g. img: 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            self._load_rgb,
            transforms.Resize((self.resize, self.resize)),
            transforms.ToTensor(),
        ])
        return tf(img), torch.tensor(label)
def main(root=r'D:\pythonProject\pythonProject39\pokeman'):
    """Load one training sample and show it in a running visdom server.

    Args:
        root: dataset root directory. The default is a raw string: the old
            plain literal relied on the invalid escape sequence '\\p', which
            Python warns about (the resulting string value is the same).
    """
    import visdom
    viz = visdom.Visdom()
    db = Pokeman(root, 224, 'train')
    # First sample; iter() works through the Dataset's __getitem__.
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape)
    viz.images(x, win='sample_x', opts=dict(title='sample_x'))


if __name__ == '__main__':
    main()
第五步:对数据进行预处理
import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
class Pokeman(Dataset):
    """Pokemon image dataset with augmentation and normalization.

    Expects one sub-directory per class under ``root``; labels follow the
    sorted sub-directory names. The shuffled ``path,label`` index is cached
    in ``root/images.csv`` so the 60/20/20 train/val/test split is stable.
    """

    # Per-channel statistics used by Normalize in __getitem__ and undone by
    # denormalize(); these are the standard ImageNet mean/std values.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, model):
        """Index the images under ``root`` and keep only the requested split.

        Args:
            root: dataset root; each sub-directory is one class.
            resize: output image side length in pixels.
            model: which split to expose -- 'train' (first 60% of the shuffled
                index), 'val' (next 20%); any other value selects the last
                20% (test).
        """
        super(Pokeman, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}
        print(root)
        for name in sorted(os.listdir(root)):
            # Only sub-directories are classes; stray files (e.g. images.csv) are skipped.
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)
        print(self.name2label)

        self.images, self.labels = self.load_csv('images.csv')
        # Compute the split boundaries once, from the full index, so images
        # and labels are sliced with exactly the same bounds.
        n = len(self.images)
        train_end, val_end = int(0.6 * n), int(0.8 * n)
        if model == 'train':
            start, stop = 0, train_end
        elif model == 'val':
            start, stop = train_end, val_end
        else:
            start, stop = val_end, n
        self.images = self.images[start:stop]
        self.labels = self.labels[start:stop]

    def load_csv(self, filename):
        """Return ``(image_paths, int_labels)`` read from ``root/filename``.

        If the CSV does not exist yet it is built first: every class directory
        is globbed for .png/.jpg/.jpeg files, the paths are shuffled, and one
        ``path,label`` row is written per image.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # e.g. 1167, 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images), images)
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The class name is the image's parent directory.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
            print('writen into csv file', filename)

        images, labels = [], []
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the Normalize step: ``x = x_hat * std + mean``.

        Works for a single image ``[3, H, W]`` or a batch ``[B, 3, H, W]``:
        mean/std are shaped ``[3, 1, 1]`` and broadcast over the channel
        dimension. They are created on ``x_hat``'s device so GPU tensors work.
        """
        mean = torch.tensor(self._MEAN, device=x_hat.device).view(3, 1, 1)
        std = torch.tensor(self._STD, device=x_hat.device).view(3, 1, 1)
        return x_hat * std + mean

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(normalized_image_tensor, label_tensor)`` for sample ``idx``."""
        # e.g. img: 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            self._load_rgb,
            # Enlarge to 1.25x, rotate by a random angle in [-15, 15] degrees,
            # then crop the central resize x resize region (trims rotation borders).
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])
        # NOTE(review): the random rotation is applied for every split,
        # including 'val'/'test', so evaluation inputs are not deterministic.
        return tf(img), torch.tensor(label)
def main(root=r'D:\pythonProject\pythonProject39\pokeman'):
    """Show one augmented training sample (denormalized) in visdom.

    Args:
        root: dataset root directory. The default is a raw string: the old
            plain literal relied on the invalid escape sequence '\\p', which
            Python warns about (the resulting string value is the same).
    """
    import visdom
    viz = visdom.Visdom()
    db = Pokeman(root, 224, 'train')
    # First sample; iter() works through the Dataset's __getitem__.
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape)
    # Undo Normalize before display, otherwise the colours are distorted.
    viz.images(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))


if __name__ == '__main__':
    main()
如果不调用 denormalize（反归一化）而直接显示，生成的图片如下（Normalize 之后像素值不再位于 [0, 1] 区间，因此颜色会明显失真）：
第六步:批量读取图片
import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
class Pokeman(Dataset):
    """Pokemon image dataset with augmentation and normalization.

    Expects one sub-directory per class under ``root``; labels follow the
    sorted sub-directory names. The shuffled ``path,label`` index is cached
    in ``root/images.csv`` so the 60/20/20 train/val/test split is stable.
    """

    # Per-channel statistics used by Normalize in __getitem__ and undone by
    # denormalize(); these are the standard ImageNet mean/std values.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, model):
        """Index the images under ``root`` and keep only the requested split.

        Args:
            root: dataset root; each sub-directory is one class.
            resize: output image side length in pixels.
            model: which split to expose -- 'train' (first 60% of the shuffled
                index), 'val' (next 20%); any other value selects the last
                20% (test).
        """
        super(Pokeman, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}
        print(root)
        for name in sorted(os.listdir(root)):
            # Only sub-directories are classes; stray files (e.g. images.csv) are skipped.
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)
        print(self.name2label)

        self.images, self.labels = self.load_csv('images.csv')
        # Compute the split boundaries once, from the full index, so images
        # and labels are sliced with exactly the same bounds.
        n = len(self.images)
        train_end, val_end = int(0.6 * n), int(0.8 * n)
        if model == 'train':
            start, stop = 0, train_end
        elif model == 'val':
            start, stop = train_end, val_end
        else:
            start, stop = val_end, n
        self.images = self.images[start:stop]
        self.labels = self.labels[start:stop]

    def load_csv(self, filename):
        """Return ``(image_paths, int_labels)`` read from ``root/filename``.

        If the CSV does not exist yet it is built first: every class directory
        is globbed for .png/.jpg/.jpeg files, the paths are shuffled, and one
        ``path,label`` row is written per image.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # e.g. 1167, 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images), images)
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The class name is the image's parent directory.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
            print('writen into csv file', filename)

        images, labels = [], []
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the Normalize step: ``x = x_hat * std + mean``.

        Works for a single image ``[3, H, W]`` or a batch ``[B, 3, H, W]``:
        mean/std are shaped ``[3, 1, 1]`` and broadcast over the channel
        dimension. They are created on ``x_hat``'s device so GPU tensors work.
        """
        mean = torch.tensor(self._MEAN, device=x_hat.device).view(3, 1, 1)
        std = torch.tensor(self._STD, device=x_hat.device).view(3, 1, 1)
        return x_hat * std + mean

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(normalized_image_tensor, label_tensor)`` for sample ``idx``."""
        # e.g. img: 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        img, label = self.images[idx], self.labels[idx]
        # Built per call (not in __init__) and with no lambdas stored on self,
        # so the dataset stays picklable for DataLoader worker processes.
        tf = transforms.Compose([
            self._load_rgb,
            # Enlarge to 1.25x, rotate by a random angle in [-15, 15] degrees,
            # then crop the central resize x resize region (trims rotation borders).
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])
        # NOTE(review): the random rotation is applied for every split,
        # including 'val'/'test', so evaluation inputs are not deterministic.
        return tf(img), torch.tensor(label)
def main(root=r'D:\pythonProject\pythonProject39\pokeman'):
    """Show one sample, then stream shuffled training batches to visdom.

    Args:
        root: dataset root directory. The default is a raw string: the old
            plain literal relied on the invalid escape sequence '\\p', which
            Python warns about (the resulting string value is the same).
    """
    import time
    import visdom
    viz = visdom.Visdom()
    db = Pokeman(root, 64, 'train')
    # First sample; iter() works through the Dataset's __getitem__.
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape)
    viz.images(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    # num_workers > 0 starts worker processes; on Windows (spawn start method)
    # this requires the `if __name__ == '__main__'` guard below.
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        # Pause so each batch stays on screen for a while.
        time.sleep(10)


if __name__ == '__main__':
    main()
对于按类别分文件夹存放的数据（root/类别名/图片），可以直接调用 torchvision 的 ImageFolder API，更简单地实现同样的功能：
# Minimal equivalent built on torchvision's ImageFolder, which assigns labels
# from the class sub-directory names under `root` (see db.class_to_idx).
# The imports and the visdom client are declared here so the snippet runs on
# its own instead of depending on names from main() above.
import time

import torchvision
import visdom
from torch.utils.data import DataLoader
from torchvision import transforms

viz = visdom.Visdom()
tf = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
# NOTE(review): earlier steps use a folder named 'pokeman'; adjust `root`
# to wherever the dataset actually lives.
db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
loader = DataLoader(db, batch_size=32, shuffle=True)
print(db.class_to_idx)
for x, y in loader:
    viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
    # Pause so each batch stays on screen for a while.
    time.sleep(10)