Pytorch系列:
- PyTorch系列(一) - PyTorch使用总览
- PyTorch系列(二) - PyTorch数据读取
- PyTorch系列(三) - PyTorch网络构建
- PyTorch系列(四) - PyTorch网络设置
参考:
本文首先介绍了有关预处理包的源码,接着介绍了在数据处理中的具体应用; 其主要目录如下:
1 PyTorch数据预处理以及源码分析 (torch.utils.data)
1.1 Dataset
Dataset
1
class torch.utils.data.Dataset
表示Dataset的抽象类。所有其他数据集都应该进行子类化。 所有子类应该override__len__
和__getitem__
,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
class Dataset(object):
# 强制所有的子类override getitem和len两个函数,否则就抛出错误;
# 输入数据索引,输出为索引指向的数据以及标签;
def __getitem__(self, index):
raise NotImplementedError
# 输出数据的长度
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
TensorDataset
class torch.utils.data.TensorDataset(*tensors)
Dataset的子类。包装tensors数据集;输入输出都是元组; 通过沿着第一个维度索引一个张量来回复每个样本。 个人感觉比较适用于数字类型的数据集,比如线性回归等。
class TensorDataset(Dataset):
def __init__(self, *tensor):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in tensors
def __len__(self):
return self.tensors[0].size(0)
ConcatDateset
class torch.utils.data.ConcatDateset(datasets)
连接多个数据集。 目的:组合不同的数据集,可能是大规模数据集,因为连续操作是随意连接的。 datasets的参数:要连接的数据集列表 datasets的样式:iterable
class ConcatDataset(Dataset):
@staticmethod
def cumsum(sequence):
# sequence是一个列表,e.g. [[1,2,3], [a,b], [4,h]]
# return 一个数据大小列表,[3, 5, 7], 明显看的出来包含数据多少,第一个代表第一个数据的大小,第二个代表第一个+第二数据的大小,最后代表所有的数据大学;
...
def __getitem__(self, idx):
# 主要是这个函数,通过bisect的类实现了任意索引数据的输出;
dataset_idx = bisect.bisect_right(self.cumulative_size, idx)
if dataset_idx == 0:
sample_idx == idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx -1]
return self.datasets[dataset_idx][sample_idx]
...
Subset
class torch.utils.data.Subset(dataset, indices)
选取特殊索引下的数据子集; dataset:数据集; indices:想要选取的数据的索引;
random_split
class torch.utils.data.random_split(dataset, lengths):
随机不重复分割数据集; dataset:要被分割的数据集 lengths:长度列表,e.g. [7, 3], 保证7+3=len(dataset)
1.2 DataLoader
DataLoader
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
数据加载器。 组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。 参数:
- dataset (Dataset) - 从中加载数据的数据集。
- batch_size (int, optional) - 批训练的数据个数。
- shuffle (bool, optional) - 是否打乱数据集(一般打乱较好)。
- sampler (Sampler, optional) - 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
- batch_sampler (Sample, optional) - 和sampler类似,返回批中的索引。
- num_workers (int, optional) - 用于数据加载的子进程数。
- collate_fn (callable, optional) - 合并样本列表以形成小批量。
- pin_memory (bool, optional) - 如果为True,数据加载器在返回去将张量复制到CUDA固定内存中。
- drop_last (bool, optional) - 如果数据集大小不能被batch_size整除, 设置为True可以删除最后一个不完整的批处理。
- timeout (numeric, optional) - 正数,收集数据的超时值。
- worker_init_fn (callabel, optional) - If not
None
, this will be called on each worker subprocess with the worker id (an int in[0, num_workers - 1]
) as input, after seeding and before data loading. (default:None
)
特别重要:DataLoader中是不断调用DataLoaderIter
DataLoaderIter
class _DataLoaderIter(loader)
从DataLoader’s数据中迭代一次。其上面DataLoader
功能都在这里; 插个眼,有空在分析这个
1.3 sampler
Sampler
class torch.utils.data.sampler.Sampler(data_source)
所有采样器的基础类; 每个采样器子类必须提供一个__iter__
方法,提供一种迭代数据集元素的索引的方法,以及返回迭代器长度__len__
方法。
class Sampler(object):
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
SequentialSampler
class torch.utils.data.SequentialSampler(data_source)
样本元素顺序排列,始终以相同的顺序。 参数:-data_source (Dataset) - 采样的数据
RandomSampler
class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
样本随机排列,如果没有Replacement,将会从打乱的数据采样,否则,。。 参数:
- data_source (Dataset) - 采样数据
- num_samples (int) - 采样数据大小,默认是全部。
- replacement (bool) - 是否放回
SubsetRandomSampler
class torch.utils.data.SubsetRandomSampler(indices)
从给出的索引中随机采样,without replacement。 参数:
- indices (sequence) - 索引序列。
BatchSampler
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
将采样封装到批处理索引。 参数:
- sampler (sampler) - 基本采样
- batch_size (int) - 批大小
- drop_last (bool) - 是否删掉最后的批次
weightedRandomSampler
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
样本元素来自[0,…,len(weights)-1], 给定概率(权重)。 参数:
- weights (list) - 权重列表。不需要加起来为1
- num_samplers (int) - 要采样数目
- replacement (bool) -
1.4 Distributed
DistributedSampler
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None)
????没读呢
1.5 其它链接
2 torchvision
计算机视觉用到的库,文档以及码源如下:
- torchvision documentation
- torchvision 其库主要包含一下内容:
- torchvision.datasets
- MNIST
- Fashion-MNIST
- EMNIST
- COCO
- LSUN
- ImageFolder
- DatasetFolder
- Imagenet-12
- CIFAR
- STL10
- SVHN
- Photo Tour
- SBU
- Flickr
- VOC
- torchvision.models
- Alexnet
- VGG
- ResNet
- SqueezeNet
- DenseNet
- Inception v3
- torchvision.transforms
- Transforms on PIL Image
- Transfroms on torch.* Tensor
- Conversion Transforms
- Generic Transforms
- Functional Transforms
- torchvision.utils
3 应用
3.1 init
具有一下图像数据如下表示:
- train
- normal
- 1.png
- 2.png
- …
- 8000.png
- tumor
- 1.png
- 2.png
- …
- 8000.png
- normal
- validation
- normal
- 1.png
- tumor
- 1.png
- normal
希望能够训练模型,使得能够识别tumor, normal两类,将tumor–>1, normal–>0。
3.2 数据读取
在PyTorch中数据的读取借口需要经过,Dataset和DatasetLoader (DatasetloaderIter)。下面就此分别介绍。
Dataset
首先导入必要的包。
import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
np.random.seed(0)
其次定义MyDataset类,为了代码整洁精简,将不必要的操作全删,e.g. 图像剪切等。
class MyDataset(Dataset):
def __init__(self, root, size=229, ):
"""
Initialize the data producer
"""
self._root = root
self._size = size
self._num_image = len(os.listdir(root))
self._img_name = os.listdir(root)
def __len__(self):
return self._num_image
def __getitem__(self, index):
img = Image.open(os.path.join(self._root, self._img_name[index]))
# PIF image: H × W × C
# torch image: C × H × W
img = np.array(img, dtype-np.float32).transpose((2, 0, 1))
return img
DataLoader
将MyDataset封装到loader器中。
from torch.utils.data import DataLoader
# 实例化MyData
dataset_tumor_train = MyDataset(root=/img/train/tumor/)
dataset_normal_train = MyDataset(root=/img/train/normal/)
dataset_tumor_validation = MyDataset(root=/img/validation/tumor/)
dataset_normal_validation = MyDataset(root=/img/validation/normal/)
# 封装到loader
dataloader_tumor_train = DataLoader(dataset_tumor_train, batch_size=10)
dataloader_normal_train = DataLoader(dataset_normal_train, batch_size=10)
dataloader_tumor_validation = DataLoader(dataset_tumor_validation, batch_size=10)
dataloader_normal_validation = DataLoader(dataset_normal_validation, batch_size=10)
3.3 train_epoch
简单将数据流接口与训练连接起来
def train_epoch(model, loss_fn, optimizer, dataloader_tumor, dataloader_normal):
model.train()
# 由于tumor图像和normal图像一样多,所以将tumor,normal连接起来,steps=len(tumor_loader)=len(normal_loader)
steps = len(dataloader_tumor)
batch_size = dataloader_tumor.batch_size
dataiter_tumor = iter(dataloader_tumor)
dataiter_normal = iter(dataloader_normal)
for step in range(steps):
data_tumor = next(dataiter_tumor)
target_tumor = [1, 1,..,1] # 和data_tumor长度相同的tensor
data_tumor = Variable(data_tumor.cuda(async=True))
target_tumor = Variable(target_tumor.cuda(async=True))
data_normal = next(dataiter_normal)
target_normal = [0, 0,..,0] #
data_normal = Variable(data_normal.cuda(async=True))
target_normal = Variable(target_normal.cuda(async=True))
idx_rand = Variable(torch.randperm(batch_size*2).cuda(async=True))
data = torch.cat([data_tumor, data_normal])[idx_rand]
target = torch.cat([target_tumor, target_normal])[idx_rand]
output = model(data)
loss = loss_fn(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
probs = output.sigmoid()
【转载】PyTorch系列 (二):pytorch数据读取的更多相关文章
-
pytorch系列 -- 9 pytorch nn.init 中实现的初始化函数 uniform, normal, const, Xavier, He initialization
本文内容:1. Xavier 初始化2. nn.init 中各种初始化函数3. He 初始化 torch.init https://pytorch.org/docs/stable/nn.html#to ...
-
infobright系列二:数据迁移
安装之后把之前infobright的数据迁移到新安装的infobright上. 1:挺掉相关的服务 2:scp 把旧数据拷到新安装的infobright上 3:修改/etc/my-ib.cnf的数据目 ...
-
zico源代码分析(二) 数据读取和解析部分
第一部分:分析篇 首先,看一下zico的页面,左侧是hostname panel,右侧是该主机对应的traces panel. 点击左侧zorka主机名,右侧panel会更新信息,在火狐浏览器中使用f ...
-
Pytorch系列:(二)数据加载
DataLoader DataLoader(dataset,batch_size=1,shuffle=False,sampler=None, batch_sampler=None,num_worker ...
-
[Pytorch]PyTorch Dataloader自定义数据读取
整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...
-
Pytorch数据读取框架
训练一个模型需要有一个数据库,一个网络,一个优化函数.数据读取是训练的第一步,以下是pytorch数据输入框架. 1)实例化一个数据库 假设我们已经定义了一个FaceLandmarksDataset数 ...
-
Pytorch数据读取与预处理实现与探索
在炼丹时,数据的读取与预处理是关键一步.不同的模型所需要的数据以及预处理方式各不相同,如果每个*都我们自己写的话,是很浪费时间和精力的.Pytorch帮我们实现了方便的数据读取与预处理方法,下面记录 ...
-
Pytorch系列教程-使用字符级RNN对姓名进行分类
前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/char_rnn_classification_t ...
-
如何入门Pytorch之二:如何搭建实用神经网络
上一节中,我们介绍了Pytorch的基本知识,如数据格式,梯度,损失等内容. 在本节中,我们将介绍如何使用Pytorch来搭建一个经典的分类神经网络. 搭建一个神经网络并训练,大致有这么四个部分: 1 ...
随机推荐
-
easyui datagrid 每条数据后添加操作按钮
easyui datagrid 每条数据后添加“编辑.查看.删除”按钮 1.给datagrid添加操作字段:字段值 <table class="easyui-datagrid" ...
-
[zz] Install VSFTP
The first two letters of vsftpd stand for "very secure" and the program was built to have ...
-
使用Slua框架开发Unity项目的重要步骤
下载与安装 下载地址 GitHub 安装过程1.下载最新版,这里, 解压缩,将Assets目录里的所有内容复制到你的工程中,对于最终产品,可以删除slua_src,例子,文档等内容,如果是开发阶段则无 ...
-
Exp3 免杀原理与实践 20164314 郭浏聿
一.实践内容 1.正确使用msf编码器,msfvenom生成如jar之类的其他文件,veil-evasion,加壳工具,使用shellcode编程 2.通过组合应用各种技术实现恶意代码免杀(0.5分) ...
-
php aes128加密
//[加密数据]AES 128 ECB模式 public function aesEncrypt($str){ $screct_key = Yii::$app->params['encryptK ...
-
20155321 《网络攻防》 Exp1 PC平台逆向破解(5)M
20155321 <网络攻防> Exp1 PC平台逆向破解(5)M 实践目标 本次实践的对象是linux的可执行文件 该程序正常执行流程是:main调用foo函数,foo函数会简单回显任何 ...
-
vue中的ajax - axios
vue中的ajax - axios axios - 简书 使用 axios 实现 ajax 方案 VUE 更好的 ajax 上传处理 axios.js vue.js 自2.0版本已经不对 vue-re ...
-
Spring系列(一):Spring的基本概念及其核心
一.Spring是什么 Spring是一种多层的J2EE应用程序框架,其核心就是提供一种新的机制管理业务对象及其依赖关系. 二.为什么要使用Spring 1. 降低组件之间的耦合度,实现软件各层之间的 ...
-
Spring 定时操作业务需求
1.定时分析 在业务需求中有的需要检测用户的状态,通过对用户状态的检测做出对此状态相应的操作,如果这种检测由运营人工检测,不仅工作量大,而且准确性不高,人工无法很好的完成工作: 问题根源:在检测用户状 ...
-
css实现心形图案
用1个标签实现心形图案,show you the code; <!DOCTYPE html> <html lang="en"> <head> & ...