ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)

时间:2022-07-20 00:55:16

第一步:构建路径与种类的映射关系

import os
from torch.utils.data import Dataset

class Pokeman(Dataset):
    """Dataset skeleton that maps each class sub-directory of ``root`` to an integer id.

    ``__len__`` and ``__getitem__`` are still placeholders at this stage.
    """

    def __init__(self,root,resize,model):
        """Store ``root`` and ``resize`` and build ``name2label`` from the directory listing.

        ``model`` is accepted but not used yet.
        """
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)
        # Plain files directly under root are skipped; iterating in sorted order
        # makes the assigned ids reproducible from run to run.
        class_dirs = [
            entry for entry in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, entry))
        ]
        self.name2label = {entry: index for index, entry in enumerate(class_dirs)}
        print(self.name2label)

    def __len__(self):
        # Placeholder: not implemented yet.
        return None

    def __getitem__(self, idx):
        # Placeholder: not implemented yet.
        return None

def main():
    """Build the training dataset once so the class-to-label mapping is printed."""
    # Raw string: '\p' is not a valid escape sequence, so the non-raw literal
    # triggers a warning on recent Python versions. The path value is unchanged.
    db =Pokeman(r'D:\pythonProject\pythonProject39\pokeman',224,'train')

if __name__ == '__main__':
    main()

ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)

第二步:载入所有的宝可梦图像

import os,glob

from torch.utils.data import Dataset

class Pokeman(Dataset):
    """Dataset skeleton: builds the class-name -> id map and lists every image file.

    ``__len__`` and ``__getitem__`` are still placeholders at this stage.
    """

    def __init__(self,root,resize,model):
        """Store ``root`` and ``resize``, build ``name2label`` and list the images.

        ``model`` is accepted but not used yet.
        """
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)
        # Only sub-directories count as classes; sorted order keeps ids stable.
        class_dirs = [
            entry for entry in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, entry))
        ]
        self.name2label = {entry: index for index, entry in enumerate(class_dirs)}
        print(self.name2label)
        self.load_csv('images.csv')

    def load_csv(self,filename):
        """Collect the png/jpg/jpeg paths of every class directory and print them.

        ``filename`` is not used yet and nothing is returned.
        """
        found = []
        for class_name in self.name2label:
            class_dir = os.path.join(self.root, class_name)
            # Same per-class order as before: png first, then jpg, then jpeg.
            for pattern in ('*.png', '*.jpg', '*.jpeg'):
                found.extend(glob.glob(os.path.join(class_dir, pattern)))
        # Sample entry: 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        print(len(found),found)

    def __len__(self):
        # Placeholder: not implemented yet.
        return None

    def __getitem__(self, idx):
        # Placeholder: not implemented yet.
        return None

def main():
    """Build the training dataset once so the mapping and image list are printed."""
    # Raw string: '\p' is not a valid escape sequence, so the non-raw literal
    # triggers a warning on recent Python versions. The path value is unchanged.
    db =Pokeman(r'D:\pythonProject\pythonProject39\pokeman',224,'train')

if __name__ == '__main__':
    main()

ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)

第三步:打散顺序,并根据路径名提取类别映射关系,构建映射文件

import csv
import os,glob
import random

from torch.utils.data import Dataset

class Pokeman(Dataset):
    """Dataset skeleton that indexes images through a shuffled CSV file under ``root``.

    ``__len__`` and ``__getitem__`` are still placeholders at this stage.
    """

    def __init__(self,root,resize,model):
        """Build ``name2label`` and load ``images``/``labels`` from ``images.csv``.

        Args:
            root: directory that contains one sub-directory per class.
            resize: stored on the instance; not used yet.
            model: accepted but not used yet.
        """
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize
        self.name2label={}
        print(root)
        # Sorted iteration keeps the ids stable; plain files under root are skipped.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

    def load_csv(self,filename):
        """Return ``(images, labels)`` read from ``root/filename``.

        The file is created first when it does not exist: every png/jpg/jpeg
        under the class directories is listed, shuffled and written as
        ``path,label`` rows. An existing file is read back as-is, so the
        shuffled order persists across runs.
        """
        csv_path = os.path.join(self.root,filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png','*.jpg','*.jpeg'):
                    images += glob.glob(os.path.join(self.root,name,pattern))
            # Sample entry: 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images),images)

            random.shuffle(images)
            with open(csv_path,mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img,self.name2label[name]])
                print('writen into csv file',filename)

        # Reading happens outside the `if` so that an already existing CSV is
        # loaded too; previously this path fell through and returned None.
        images,labels = [],[]
        with open(csv_path,newline='') as f:
            for img,label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images,labels

    def __len__(self):
        # Placeholder: not implemented yet.
        pass

    def __getitem__(self, idx):
        # Placeholder: not implemented yet.
        pass

def main():
    """Build the training dataset once, creating ``images.csv`` if it is missing."""
    # Raw string: '\p' is not a valid escape sequence, so the non-raw literal
    # triggers a warning on recent Python versions. The path value is unchanged.
    db =Pokeman(r'D:\pythonProject\pythonProject39\pokeman',224,'train')

if __name__ == '__main__':
    main()

ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)
ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)

第四步:完善选取、获取图片信息功能并可视化

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    """Image dataset indexed through a shuffled ``images.csv`` stored under ``root``.

    The CSV order is split 60% / 20% / 20% into train, validation and test.
    """

    def __init__(self,root,resize,model):
        """Build the label map, load the CSV index and keep the requested split.

        Args:
            root: directory that contains one sub-directory per class.
            resize: edge length (pixels) every image is resized to.
            model: ``'train'`` keeps the first 60% of the rows, ``'val'`` the
                next 20%; any other value keeps the last 20%.
        """
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize

        self.name2label={}
        print(root)
        # Sorted iteration keeps the ids stable; plain files under root are skipped.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

        # Compute both boundaries once from the full length and apply the same
        # slice to images and labels. Previously the label bounds were derived
        # from the already truncated image list, which desynchronised the two
        # lists for the 'val' and test splits.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if model == 'train':
            chosen = slice(None,train_end)
        elif model == 'val':
            chosen = slice(train_end,val_end)
        else:
            chosen = slice(val_end,None)
        self.images = self.images[chosen]
        self.labels = self.labels[chosen]

    def load_csv(self,filename):
        """Return ``(images, labels)`` read from ``root/filename``.

        The file is created first when it does not exist: every png/jpg/jpeg
        under the class directories is listed, shuffled and written as
        ``path,label`` rows. An existing file is read back as-is, so the
        shuffled order (and therefore the split) persists across runs.
        """
        csv_path = os.path.join(self.root,filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png','*.jpg','*.jpeg'):
                    images += glob.glob(os.path.join(self.root,name,pattern))
            # Sample entry: 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images),images)

            random.shuffle(images)
            with open(csv_path,mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img,self.name2label[name]])
                print('writen into csv file',filename)

        images,labels = [],[]
        with open(csv_path,newline='') as f:
            for img,label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images,labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        # path looks like 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        path,label = self.images[idx],self.labels[idx]
        tf = transforms.Compose([
            transforms.Resize((self.resize,self.resize)),
            transforms.ToTensor()
        ])

        # The context manager closes the underlying file once the pixel data
        # has been converted to RGB and transformed.
        with Image.open(path) as raw:
            img = tf(raw.convert('RGB'))
        label = torch.tensor(label)

        return img,label

def main():
    """Show the first training sample in Visdom."""
    import visdom
    viz = visdom.Visdom()

    # Raw string: '\p' is not a valid escape sequence, so the non-raw literal
    # triggers a warning on recent Python versions. The path value is unchanged.
    db =Pokeman(r'D:\pythonProject\pythonProject39\pokeman',224,'train')
    # Fetch the first sample of the dataset.
    x,y = next(iter(db))
    print('sample:',x.shape,y.shape)
    viz.images(x,win='sample_x',opts=dict(title='sample_x'))

if __name__ == '__main__':
    main()

ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)

ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)

第五步:对数据进行预处理

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    """Image dataset indexed through a shuffled ``images.csv`` stored under ``root``.

    The CSV order is split 60% / 20% / 20% into train, validation and test.
    Samples are augmented and normalised in ``__getitem__``.
    """

    def __init__(self,root,resize,model):
        """Build the label map, load the CSV index and keep the requested split.

        Args:
            root: directory that contains one sub-directory per class.
            resize: edge length (pixels) of the final centre crop.
            model: ``'train'`` keeps the first 60% of the rows, ``'val'`` the
                next 20%; any other value keeps the last 20%.
        """
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize

        self.name2label={}
        print(root)
        # Sorted iteration keeps the ids stable; plain files under root are skipped.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

        # Compute both boundaries once from the full length and apply the same
        # slice to images and labels. Previously the label bounds were derived
        # from the already truncated image list, which desynchronised the two
        # lists for the 'val' and test splits.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if model == 'train':
            chosen = slice(None,train_end)
        elif model == 'val':
            chosen = slice(train_end,val_end)
        else:
            chosen = slice(val_end,None)
        self.images = self.images[chosen]
        self.labels = self.labels[chosen]

    def load_csv(self,filename):
        """Return ``(images, labels)`` read from ``root/filename``.

        The file is created first when it does not exist: every png/jpg/jpeg
        under the class directories is listed, shuffled and written as
        ``path,label`` rows. An existing file is read back as-is, so the
        shuffled order (and therefore the split) persists across runs.
        """
        csv_path = os.path.join(self.root,filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png','*.jpg','*.jpeg'):
                    images += glob.glob(os.path.join(self.root,name,pattern))
            # Sample entry: 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images),images)

            random.shuffle(images)
            with open(csv_path,mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img,self.name2label[name]])
                print('writen into csv file',filename)

        images,labels = [],[]
        with open(csv_path,newline='') as f:
            for img,label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images,labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def denormalize(self,x_hat):
        """Invert the ``Normalize`` step of ``__getitem__`` for display purposes.

        Normalisation computes ``x_hat = (x - mean) / std``; this returns
        ``x = x_hat * std + mean`` using the same per-channel statistics.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # [3] => [3, 1, 1] so the statistics broadcast over the two trailing axes.
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        # path looks like 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        path,label = self.images[idx],self.labels[idx]
        enlarged = int(self.resize*1.25)
        tf = transforms.Compose([
            transforms.Resize((enlarged,enlarged)),  # scale up before cropping
            transforms.RandomRotation(15),  # random rotation
            transforms.CenterCrop(self.resize),  # centre crop to the target size
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],  # widely used default statistics
                                 std=[0.229, 0.224, 0.225])
        ])

        # The context manager closes the underlying file once the pixel data
        # has been converted to RGB and transformed.
        with Image.open(path) as raw:
            img = tf(raw.convert('RGB'))
        label = torch.tensor(label)

        return img,label

def main():
    """Show the first (de-normalised) training sample in Visdom."""
    import visdom
    viz = visdom.Visdom()

    # Raw string: '\p' is not a valid escape sequence, so the non-raw literal
    # triggers a warning on recent Python versions. The path value is unchanged.
    db =Pokeman(r'D:\pythonProject\pythonProject39\pokeman',224,'train')
    # Fetch the first sample of the dataset.
    x,y = next(iter(db))
    print('sample:',x.shape,y.shape)
    # Undo the normalisation before display.
    viz.images(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))

if __name__ == '__main__':
    main()

ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)
如果没有denormalize生成图片如下:
ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)

第六步:批量读取图片

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    """Image dataset indexed through a shuffled ``images.csv`` stored under ``root``.

    The CSV order is split 60% / 20% / 20% into train, validation and test.
    Samples are augmented and normalised in ``__getitem__``.
    """

    def __init__(self,root,resize,model):
        """Build the label map, load the CSV index and keep the requested split.

        Args:
            root: directory that contains one sub-directory per class.
            resize: edge length (pixels) of the final centre crop.
            model: ``'train'`` keeps the first 60% of the rows, ``'val'`` the
                next 20%; any other value keeps the last 20%.
        """
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize

        self.name2label={}
        print(root)
        # Sorted iteration keeps the ids stable; plain files under root are skipped.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

        # Compute both boundaries once from the full length and apply the same
        # slice to images and labels. Previously the label bounds were derived
        # from the already truncated image list, which desynchronised the two
        # lists for the 'val' and test splits.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if model == 'train':
            chosen = slice(None,train_end)
        elif model == 'val':
            chosen = slice(train_end,val_end)
        else:
            chosen = slice(val_end,None)
        self.images = self.images[chosen]
        self.labels = self.labels[chosen]

    def load_csv(self,filename):
        """Return ``(images, labels)`` read from ``root/filename``.

        The file is created first when it does not exist: every png/jpg/jpeg
        under the class directories is listed, shuffled and written as
        ``path,label`` rows. An existing file is read back as-is, so the
        shuffled order (and therefore the split) persists across runs.
        """
        csv_path = os.path.join(self.root,filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png','*.jpg','*.jpeg'):
                    images += glob.glob(os.path.join(self.root,name,pattern))
            # Sample entry: 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images),images)

            random.shuffle(images)
            with open(csv_path,mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img,self.name2label[name]])
                print('writen into csv file',filename)

        images,labels = [],[]
        with open(csv_path,newline='') as f:
            for img,label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images,labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def denormalize(self,x_hat):
        """Invert the ``Normalize`` step of ``__getitem__`` for display purposes.

        Normalisation computes ``x_hat = (x - mean) / std``; this returns
        ``x = x_hat * std + mean`` using the same per-channel statistics.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # [3] => [3, 1, 1] so the statistics broadcast over the two trailing axes.
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        # path looks like 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        path,label = self.images[idx],self.labels[idx]
        enlarged = int(self.resize*1.25)
        tf = transforms.Compose([
            transforms.Resize((enlarged,enlarged)),  # scale up before cropping
            transforms.RandomRotation(15),  # random rotation
            transforms.CenterCrop(self.resize),  # centre crop to the target size
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],  # widely used default statistics
                                 std=[0.229, 0.224, 0.225])
        ])

        # The context manager closes the underlying file once the pixel data
        # has been converted to RGB and transformed.
        with Image.open(path) as raw:
            img = tf(raw.convert('RGB'))
        label = torch.tensor(label)

        return img,label

def main():
    """Show one training sample, then stream shuffled batches to Visdom."""
    import visdom
    import time
    viz = visdom.Visdom()

    # Raw string: '\p' is not a valid escape sequence, so the non-raw literal
    # triggers a warning on recent Python versions. The path value is unchanged.
    db =Pokeman(r'D:\pythonProject\pythonProject39\pokeman',64,'train')
    # Fetch the first sample of the dataset.
    x,y = next(iter(db))
    print('sample:',x.shape,y.shape)
    viz.images(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))
    loader = DataLoader(db,batch_size=32,shuffle=True,num_workers=8)

    for x ,y in loader:
        # Undo the normalisation before display; 8 images per row.
        viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))
        viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
        # Pause so each batch stays visible before the window is overwritten.
        time.sleep(10)


if __name__ == '__main__':
    main()

ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)
对于已按类别分目录存放的规整数据集,可以更简单地直接调用现成的API:

# Self-contained variant: bind every name the snippet uses so it can run on its own.
import time

import torchvision
import visdom
from torch.utils.data import DataLoader
from torchvision import transforms

viz = visdom.Visdom()

tf = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
])
# ImageFolder derives the class labels from the sub-directory names under root.
db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
loader = DataLoader(db, batch_size=32, shuffle=True)

print(db.class_to_idx)

for x,y in loader:
    viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

    # Pause so each batch stays visible before the window is overwritten.
    time.sleep(10)