.CIFAR10模块使用讲解

时间:2024-11-11 08:11:08

数据集

所有的数据集都是的子类, 它们实现了__getitem__和__len__方法。因此,它们都可以传递给.

 transform = (
  [(),
 ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

 trainset = .CIFAR10(root='./data', train=True,
  download=True, transform=transform)
 trainloader = (trainset, batch_size=4,
  shuffle=True, num_workers=2)

目前为止,收录的数据集包括:

__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
           'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
           'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
           'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
           'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
           'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101',
           'Places365', 'Kitti', "INaturalist", "LFWPeople", "LFWPairs"
           )

.CIFAR10模块

CIFAR-10和CIFAR-100为8000万张微小图像数据集的子集。它们是由 Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.收集的。
CIFAR-10数据集包含10类60000幅32x32彩色图像,每个类6000幅图像。训练图像50000张,测试图像10000张。该数据集被分为5个训练批和1个测试批,每个批包含10000张图像。测试批正好包含从每个类中随机选择的1000张图像。训练批以随机顺序包含剩余的图像,但有些训练批可能包含一个类的图像多于另一个类。在它们之间,训练批恰好包含来自每个类的5000张图像。
以下是数据集中的类,以及每个类的10张随机图片:
在这里插入图片描述
这些类是完全互斥的。汽车和卡车之间没有重叠。“汽车”包括轿车、越野车之类的东西。“卡车”只包括大卡车。这两项都不包括皮卡。

.CIFAR10源码

class CIFAR10(VisionDataset):
    """`CIFAR10 </~kriz/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. , ````
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    base_folder = 'cifar-10-batches-py'
    url = "/~kriz/"
    filename = ""
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': '',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

         = train  # training set or test set

        if download:
            ()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if :
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        : Any = []
         = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = (, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = (f, encoding='latin1')
                (entry['data'])
                if 'labels' in entry:
                    (entry['labels'])
                else:
                    (entry['fine_labels'])

         = ().reshape(-1, 3, 32, 32)
         = ((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self) -> None:
        path = (, self.base_folder, ['filename'])
        if not check_integrity(path, ['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            data = (infile, encoding='latin1')
             = data[['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate()}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = [index], [index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = (img)

        if  is not None:
            img = (img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len()

    def _check_integrity(self) -> bool:
        root = 
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = (root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self) -> None:
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(, , filename=, md5=self.tgz_md5)

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if  is True else "Test")

参数说明

root (string):数据集所在目录的根目录 如果download设置为True。“cifar-10-batches-py '”存在,则将被保存至该目录
train :如果为True,则从训练集创建数据集,否则从测试集创建。
transform::(bool,可选)一个接受PIL图像的函数/变换 并返回转换后的版本。 如”transforms“,“RandomCrop”
target_transform:(可调用,可选):一个作用于目的转换函数。
download (bool,可选):如果为true,则从internet下载数据集 ,将其放在根目录中。 如果数据集已经下载,则不会 再次下载。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
 download=True, transform=transform)

在这里插入图片描述