The conventional approach takes 10 s, while loading from .dat files takes 0.6 s.
|
import os
import random
import time

import torch

from common.coco_dataset import COCODataset


def gen_data(batch_size, data_path, target_path):
    """Iterate over the dataset the conventional way and time each batch."""
    os.makedirs(target_path, exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
        COCODataset(data_path, (352, 352), is_training=False, is_scene=True),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        drop_last=True,
    )
    start = time.time()
    for step, samples in enumerate(dataloader):
        images, labels, image_paths = samples["image"], samples["label"], samples["img_path"]
        print("time", images.size(0), time.time() - start)
        start = time.time()
        # Uncomment to cache each batch on disk as <step>.dat:
        # torch.save(samples, os.path.join(target_path, str(step) + '.dat'))
        print(step)


def cat_100(target_path, batch_size=100):
    """Load the cached .dat files in random order and merge them in groups of batch_size files."""
    # Assumes the directory holds only files named 0.dat ... N-1.dat.
    paths = os.listdir(target_path)
    li = list(range(len(paths)))
    random.shuffle(li)
    images = []
    labels = []
    image_paths = []
    start = time.time()
    for i in range(len(paths)):
        samples = torch.load(os.path.join(target_path, str(li[i]) + ".dat"))
        image, label, image_path = samples["image"], samples["label"], samples["img_path"]
        images.append(image.cuda())  # move each tensor to the GPU as it is loaded
        labels.append(label.cuda())
        image_paths.append(image_path)
        if i % batch_size == batch_size - 1:
            images = torch.cat(images, 0)
            print("time", images.size(0), time.time() - start)
            images = []
            labels = []
            image_paths = []
            start = time.time()


if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '3'
    batch_size = 320
    # target_path = 'd:/test_1000/'
    target_path = 'd:/img_2/'
    data_path = r'D:\dataset\origin_all_datas\_2train'
    gen_data(batch_size, data_path, target_path)
    # get_data(target_path, batch_size)
    # cat_100(target_path, batch_size)
|
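The caching idea above can be tried without the COCO data. The following is a minimal sketch, not the author's code: `ToyDataset`, `cache_batches` and `load_batch` are hypothetical names, and the tiny tensors only stand in for real images. It writes each collated batch to `<step>.dat` with `torch.save` and reads one back with `torch.load`.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for COCODataset: every sample is a dict."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            "image": torch.full((3, 4, 4), float(idx)),
            "label": torch.tensor([idx]),
            "img_path": f"img_{idx}.jpg",
        }


def cache_batches(dataloader, target_path):
    """Save every collated batch as <step>.dat and return the number of files."""
    os.makedirs(target_path, exist_ok=True)
    count = 0
    for step, samples in enumerate(dataloader):
        torch.save(samples, os.path.join(target_path, f"{step}.dat"))
        count += 1
    return count


def load_batch(target_path, step):
    """Read one cached batch back from disk."""
    return torch.load(os.path.join(target_path, f"{step}.dat"))


cache_dir = tempfile.mkdtemp()
loader = DataLoader(ToyDataset(), batch_size=4, shuffle=False, drop_last=True)
num_files = cache_batches(loader, cache_dir)
first = load_batch(cache_dir, 0)
print(num_files, first["image"].shape, first["img_path"])
```

Because a cached file already holds a whole collated batch, reading it back skips per-image decoding and resizing, which is presumably where the conventional loader spends most of its time.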
This version also reads the data fairly quickly: 450 ms with a batch_size of 320.
|
def cat_100(target_path, batch_size=100):
    """Same as before, but keep tensors on the CPU until a full group is ready."""
    paths = os.listdir(target_path)
    li = list(range(len(paths)))
    random.shuffle(li)
    images = []
    labels = []
    image_paths = []
    start = time.time()
    for i in range(len(paths)):
        samples = torch.load(os.path.join(target_path, str(li[i]) + ".dat"))
        image, label, image_path = samples["image"], samples["label"], samples["img_path"]
        images.append(image)  # stays on the CPU for now
        labels.append(label)
        image_paths.append(image_path)
        if i % batch_size < batch_size - 1:
            continue  # group not full yet
        # Group is full: move the tensors to the GPU and concatenate them.
        images = torch.cat([image.cuda() for image in images], 0)
        print("time", images.size(0), time.time() - start)
        images = []
        labels = []
        image_paths = []
        start = time.time()
|
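The same loading loop can be written as a generator. This is a sketch under stated assumptions rather than the author's method: `iter_merged_batches` is a hypothetical name, the files are found by their `.dat` suffix instead of by index, and the tensors are concatenated on the CPU first and then moved to the device in one transfer (a swapped-in variation on the code above). It falls back to the CPU when no GPU is available, and leftover files that do not fill a group are dropped, as in the original loop.

```python
import os
import random
import tempfile

import torch


def iter_merged_batches(target_path, group_size, device=None):
    """Yield (images, labels, image_paths) merged from group_size cached files."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    names = [name for name in os.listdir(target_path) if name.endswith(".dat")]
    random.shuffle(names)
    images, labels, image_paths = [], [], []
    for name in names:
        samples = torch.load(os.path.join(target_path, name))
        images.append(samples["image"])
        labels.append(samples["label"])
        image_paths.extend(samples["img_path"])
        if len(images) == group_size:
            # Concatenate on the CPU, then do a single transfer to the device.
            yield (
                torch.cat(images, 0).to(device),
                torch.cat(labels, 0).to(device),
                image_paths,
            )
            images, labels, image_paths = [], [], []


# Build four small fake cache files so the sketch runs on its own.
demo_dir = tempfile.mkdtemp()
for step in range(4):
    fake_batch = {
        "image": torch.zeros(2, 3, 4, 4),
        "label": torch.full((2, 1), step),
        "img_path": [f"img_{step}_{k}.jpg" for k in range(2)],
    }
    torch.save(fake_batch, os.path.join(demo_dir, f"{step}.dat"))

for merged_images, merged_labels, merged_paths in iter_merged_batches(demo_dir, 2, device="cpu"):
    print(merged_images.shape, merged_labels.shape, len(merged_paths))
```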
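Addendum: fixes for problems with PyTorch data loading and processing
While working through the Chinese PyTorch documentation recently I ran into a few small problems. They are solved now, and I record the errors here:
Error when reading the dataset:
AttributeError: 'Series' object has no attribute 'as_matrix'
Error when displaying an image:
ValueError: Masked arrays must be 1-D
When displaying a single image, the figure window flashes and disappears
Error when displaying several scatter plots:
TypeError: show_landmarks() got an unexpected keyword argument 'image'
Solutions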
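The first error comes from the line below. The goal is to convert the Series into a matrix, which a call to np.mat does. (Series.as_matrix() was removed in pandas 1.0.)
Before
|
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
|
After
|
landmarks = np.mat(landmarks_frame.iloc[n, 1:])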
|
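An alternative to `np.mat` is `Series.to_numpy()`, which pandas documents as the replacement for `as_matrix()` and which returns a plain ndarray. NumPy 2.0 also dropped the `np.mat` alias in favor of `np.asmatrix`. The sketch below is an illustration only: the small DataFrame and the `read_landmarks` helper are hypothetical stand-ins for the tutorial's landmarks CSV.

```python
import pandas as pd

# Hypothetical stand-in for the landmarks CSV: the first column is the
# file name, the remaining columns are x/y coordinate pairs.
landmarks_frame = pd.DataFrame(
    [["a.jpg", 10, 20, 30, 40], ["b.jpg", 1, 2, 3, 4]],
    columns=["image_name", "part_0_x", "part_0_y", "part_1_x", "part_1_y"],
)


def read_landmarks(frame, n):
    """Return the landmarks of row n as a float ndarray of shape (k, 2)."""
    row = frame.iloc[n, 1:].to_numpy()
    return row.astype("float").reshape(-1, 2)


landmarks = read_landmarks(landmarks_frame, 0)
xs, ys = landmarks[:, 0], landmarks[:, 1]  # both are 1-D arrays
print(landmarks.shape, xs.ndim, ys.ndim)
```

With a plain ndarray, each column slice is already one-dimensional, so it can be passed to `plt.scatter` as is.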
The x and y coordinates passed to scatter should each be a vector or a list, so calling tolist() on the landmarks slices fixes the error.
Before
|
plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
|
After
|
plt.scatter(landmarks[:, 0].tolist(), landmarks[:, 1].tolist(), s=10, marker='.', c='r')
|
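The "must be 1-D" error is consistent with `landmarks` being a NumPy matrix, since a column sliced from a matrix keeps two dimensions. The short sketch below shows the difference in shapes; the sample coordinates are made up for illustration.

```python
import numpy as np

points = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]
as_matrix = np.asmatrix(points)  # matrix type, like the result of np.mat
as_array = np.asarray(points)    # plain ndarray

matrix_col = as_matrix[:, 0]  # a matrix slice stays 2-D
array_col = as_array[:, 0]    # an ndarray slice is 1-D

# One way to flatten a matrix column into a 1-D array:
flat_from_matrix = np.asarray(matrix_col).ravel()

print(matrix_col.shape, array_col.shape, flat_from_matrix.shape)
```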
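If interactive mode was turned on earlier with plt.ion(), plt.ioff() must be called before plt.show(); otherwise the figure window flashes and closes. Here plt.ioff() goes directly into the function, so it does not have to be repeated before every plt.show().
Before
|
def show_landmarks(imgs, landmarks):
    '''Show an image with landmarks.'''
    plt.imshow(imgs)
    plt.scatter(landmarks[:, 0].tolist(), landmarks[:, 1].tolist(), s=10, marker='.', c='r')  # draw red dots
    plt.pause(1)  # keep the plot window up briefly
|
After
|
def show_landmarks(imgs, landmarks):
    '''Show an image with landmarks.'''
    plt.imshow(imgs)
    plt.scatter(landmarks[:, 0].tolist(), landmarks[:, 1].tolist(), s=10, marker='.', c='r')  # draw red dots
    plt.pause(1)  # keep the plot window up briefly
    plt.ioff()  # leave interactive mode so a later plt.show() blocks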
|
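The interactive-mode handling can also be kept outside the drawing function. This is a sketch of one possible arrangement, not the tutorial's code: `show_and_block` is a hypothetical helper, the `pause` parameter is an addition for illustration, and `landmarks` is assumed to be a plain ndarray of shape (k, 2).

```python
import matplotlib.pyplot as plt
import numpy as np


def show_landmarks(image, landmarks, pause=1.0):
    """Draw the image with red landmark dots and keep it up for `pause` seconds."""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker=".", c="r")
    plt.pause(pause)


def show_and_block():
    """Leave interactive mode, then show the figure until the window is closed."""
    plt.ioff()
    plt.show()


# Made-up data so the functions can be tried without the dataset.
demo_image = np.zeros((8, 8))
demo_landmarks = np.array([[1.0, 1.0], [4.0, 5.0]])
```

A typical sequence would be to call plt.ion() once, call show_landmarks(demo_image, demo_landmarks) for each sample, and finish with show_and_block().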
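Online sources say that for a dict-type sample, **sample unpacks the value under each key into the call, but here it raises an error. The reason is that ** passes each key as a keyword argument, and the function's first parameter is named imgs, not image. Writing the arguments out explicitly works.
Before
|
show_landmarks(**sample)
|
After
|
show_landmarks(sample['image'], sample['landmarks'])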
|
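The behavior of `**sample` depends only on whether the dict keys match the parameter names. The sketch below uses two hypothetical stand-in functions that return a string instead of plotting, so the difference shows up without matplotlib.

```python
def show_landmarks_imgs(imgs, landmarks):
    """First parameter is named `imgs`, as in the function above."""
    return f"{imgs} with {len(landmarks)} landmarks"


def show_landmarks_image(image, landmarks):
    """First parameter is named `image`, matching the dict key."""
    return f"{image} with {len(landmarks)} landmarks"


sample = {"image": "face.jpg", "landmarks": [(1, 2), (3, 4)]}

# Keys are passed as keyword arguments, so `image` has no matching parameter here.
try:
    show_landmarks_imgs(**sample)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Either rename the parameter to match the key, or pass the values explicitly.
matched = show_landmarks_image(**sample)
explicit = show_landmarks_imgs(sample["image"], sample["landmarks"])
print(error_message)
print(matched, "|", explicit)
```

So renaming the parameter from imgs to image would be another way to make show_landmarks(**sample) work.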
The above is my personal experience; I hope it serves as a useful reference. If anything is wrong or incomplete, corrections are welcome.
Original link: https://blog.csdn.net/jacke121/article/details/85236561