shuffle = False时,不打乱数据顺序
shuffle = True,随机打乱
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
|
import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset

# Build a small HDF5 file with a 'data' and a 'label' dataset so the
# Dataset class defined below has something to read back.
data1 = np.array([[1, 2, 3],
                  [2, 5, 6],
                  [3, 5, 6],
                  [4, 5, 6]])
data2 = np.array([[1, 1, 1],
                  [1, 2, 6],
                  [1, 3, 6],
                  [1, 4, 6]])
# BUG FIX: the original opened the file in 'w' mode and never closed it,
# so reopening 'train.h5' for reading later could see unflushed data or a
# locked file. The context manager guarantees the file is flushed and closed.
with h5py.File('train.h5', 'w') as h5f:
    h5f.create_dataset('data', data=data1)    # str('data') wrapper was redundant
    h5f.create_dataset('label', data=data2)
class Dataset(Dataset):
    """Map-style dataset backed by the 'data'/'label' arrays in train.h5.

    NOTE(review): the class name shadows the torch.utils.data.Dataset base
    class it inherits from; kept unchanged so existing callers still work,
    but a rename (e.g. H5Dataset) would be clearer.
    """

    def __init__(self):
        # The file handle stays open for the lifetime of the dataset;
        # h5py datasets are indexed lazily, so rows are read on demand.
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']

    def __getitem__(self, index):
        # Read one row from the HDF5 file and wrap it as a torch tensor.
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label

    def __len__(self):
        # BUG FIX: the original had a syntax error here ('= =' instead of '==').
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0]
# Iterate over the dataset in shuffled batches of two samples.
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                          batch_size=2,
                          shuffle=True)
# IDIOM FIX: the enumerate index was never used, so unpack the batch directly.
for train_data, label in loader_train:
    print(train_data)
|
pytorch DataLoader使用细节
背景:
我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvision.transforms), 但是没看到数据扩增。后来搞明白了: 在 pytorch 中, 数据扩增是由 torchvision.transforms + torch.utils.data.DataLoader + 多个 epoch 共同作用完成的。
数据变换共有以下内容
1
2
3
4
5
|
# Standard augmentation pipeline: resize, random crop, tensor conversion,
# then per-channel normalization to roughly the [-1, 1] range.
composed = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.RandomCrop(300),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])
|
简单的数据读取类, 仅返回PIL格式的image:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
class MyDataset(data.Dataset):
    """Minimal image dataset: reads image file names from a CSV file and
    returns PIL images, transformed when a transform is supplied."""

    def __init__(self, labels_file, root_dir, transform=None):
        # Each CSV row's first column is an image file name relative to root_dir.
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels_file)

    def __getitem__(self, idx):
        # BUG FIX: the original referenced the *global* 'root_dir' instead of
        # the instance attribute, which breaks whenever the global is absent
        # or differs from the value passed to __init__.
        im_name = os.path.join(self.root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        if self.transform:
            im = self.transform(im)
        return im
|
下面是主程序
1
2
3
4
5
6
7
8
9
10
11
|
labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"

dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)

"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张) """
# Two passes over the loader: each epoch re-applies the random transforms,
# so the same source images come out as differently augmented crops.
for epoch in range(2):
    plt.figure(figsize=(6, 6))
    for idx, batch in enumerate(dataloader):
        # Drop the batch dimension and move channels last for imshow.
        img = batch[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, idx + 1)
        plt.imshow(img)
|
从上述图片中可以看到, 在每个epoch阶段实际上是对原始图片重新使用了transform, 这就造就了数据的扩增
以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_35752161/article/details/110875040