一、数据集介绍
点击打开链接17_category_flower 是一个不同种类鲜花的图像数据,包含 17 不同种类的鲜花,每类 80 张该类鲜花的图片,鲜花种类是英国地区常见鲜花。下载数据后解压文件,然后将不同的花剪切到对应的文件夹,如下图所示:
每个文件夹下面有80个图片文件。
二、使用的工具
首先是在tensorflow框架下,然后介绍一下用到的两个库,一个是os,一个是pil。pil(python imaging library)是 python 中最常用的图像处理库,而image类又是 pil库中一个非常重要的类,通过这个类来创建实例可以有直接载入图像文件,读取处理过的图像和通过抓取的方法得到的图像这三种方法。
三、代码实现
我们是通过tfrecords来创建数据集的,tfrecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件(label)。
1、制作tfrecords文件
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
import os
import tensorflow as tf
from pil import image # 注意image,后面会用到
import matplotlib.pyplot as plt
import numpy as np
cwd = 'd:\pycharm community edition 2017.2.3\work\google_net\jpg\\'
classes = { 'daffodil' , 'snowdrop' , 'lilyvalley' , 'bluebell' , 'crocus' , 'iris' , 'tigerlily' , 'tulip' , 'fritiuary' ,
'sunflower' , 'daisy' , 'coltsfoot' , 'dandelion' , 'cowslip' , 'buttercup' , 'windflower' , 'pansy' } # 花为 设定 17 类
writer = tf.python_io.tfrecordwriter( "flower_train.tfrecords" ) # 要生成的文件
for index, name in enumerate (classes):
class_path = cwd + name + '\\'
for img_name in os.listdir(class_path):
img_path = class_path + img_name # 每一个图片的地址
img = image. open (img_path)
img = img.resize(( 224 , 224 ))
img_raw = img.tobytes() # 将图片转化为二进制格式
example = tf.train.example(features = tf.train.features(feature = {
"label" : tf.train.feature(int64_list = tf.train.int64list(value = [index])),
'img_raw' : tf.train.feature(bytes_list = tf.train.byteslist(value = [img_raw]))
})) # example对象对label和image数据进行封装
writer.write(example.serializetostring()) # 序列化为字符串
writer.close()
|
首先将文件移动到对应的路径:
d:\pycharm community edition 2017.2.3\work\google_net\jpg
然后对每个文件下的图片进行读写和相应的大小惊醒改变,具体过程是使用tf.train.example来定义我们要填入的数据格式,其中label即为标签,也就是最外层的文件夹名字,img_raw为易经理二进制化的图片。然后使用tf.python_io.tfrecordwriter来写入。基本的,一个example中包含features,features里包含feature(这里没s)的字典。最后,feature里包含有一个 floatlist, 或者bytelist,或者int64list。就这样,我们把相关的信息都存到了一个文件中,所以前面才说不用单独的label文件。而且读取也很方便。
执行完以上代码就会出现如下图所示的tf文件
2、读取tfrecord文件
制作完文件后,将该文件读入到数据流中,具体代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
def read_and_decode(filename): # 读入dog_train.tfrecords
filename_queue = tf.train.string_input_producer([filename]) # 生成一个queue队列
reader = tf.tfrecordreader()
_, serialized_example = reader.read(filename_queue) # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
features = {
'label' : tf.fixedlenfeature([], tf.int64),
'img_raw' : tf.fixedlenfeature([], tf.string),
}) # 将image数据和label取出来
img = tf.decode_raw(features[ 'img_raw' ], tf.uint8)
img = tf.reshape(img, [ 224 , 224 , 3 ]) # reshape为128*128的3通道图片
img = tf.cast(img, tf.float32) * ( 1. / 255 ) - 0.5 # 在流中抛出img张量
label = tf.cast(features[ 'label' ], tf.int32) # 在流中抛出label张量
return img, label
|
注意,feature的属性“label”和“img_raw”名称要和制作时统一 ,返回的img数据和label数据一一对应。
3、显示tfrecord格式的图片
为了知道tf 文件的具体内容,或者是怕图片对应的label出错,可以将数据流以图片的形式读出来并保存以便查看,具体的代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
filename_queue = tf.train.string_input_producer([ "flower_train.tfrecords" ]) # 读入流中
reader = tf.tfrecordreader()
_, serialized_example = reader.read(filename_queue) # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
features = {
'label' : tf.fixedlenfeature([], tf.int64),
'img_raw' : tf.fixedlenfeature([], tf.string),
}) # 取出包含image和label的feature对象
image = tf.decode_raw(features[ 'img_raw' ], tf.uint8)
image = tf.reshape(image, [ 224 , 224 , 3 ])
label = tf.cast(features[ 'label' ], tf.int32)
label = tf.one_hot(label, 17 , 1 , 0 )
with tf.session() as sess: # 开始一个会话
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord = tf.train.coordinator()
threads = tf.train.start_queue_runners(coord = coord)
for i in range ( 100 ):
example, l = sess.run([image, label]) # 在会话中取出image和label
img = image.fromarray(example, 'rgb' ) # 这里image是之前提到的
img.save(cwd + str (i) + '_' 'label_' + str (l) + '.jpg' ) # 存下图片
print (example, l)
coord.request_stop()
coord.join(threads)
|
执行以上代码后,当前项目对应的文件夹下会生成100张图片,还有对应的label,如下图所示:
在这里我们可以看到,前80个图片文件的label是1,后20个图片的label是2。 由此可见,我们一开始制作tfrecord文件时,图片分类正确。
完整代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
|
import os
import tensorflow as tf
from pil import image # 注意image,后面会用到
import matplotlib.pyplot as plt
import numpy as np
cwd = 'd:\pycharm community edition 2017.2.3\work\google_net\jpg\\'
classes = { 'daffodil' , 'snowdrop' , 'lilyvalley' , 'bluebell' , 'crocus' , 'iris' , 'tigerlily' , 'tulip' , 'fritiuary' ,
'sunflower' , 'daisy' , 'coltsfoot' , 'dandelion' , 'cowslip' , 'buttercup' , 'windflower' , 'pansy' } # 花为 设定 17 类
writer = tf.python_io.tfrecordwriter( "flower_train.tfrecords" ) # 要生成的文件
for index, name in enumerate (classes):
class_path = cwd + name + '\\'
for img_name in os.listdir(class_path):
img_path = class_path + img_name # 每一个图片的地址
img = image. open (img_path)
img = img.resize(( 224 , 224 ))
img_raw = img.tobytes() # 将图片转化为二进制格式
example = tf.train.example(features = tf.train.features(feature = {
"label" : tf.train.feature(int64_list = tf.train.int64list(value = [index])),
'img_raw' : tf.train.feature(bytes_list = tf.train.byteslist(value = [img_raw]))
})) # example对象对label和image数据进行封装
writer.write(example.serializetostring()) # 序列化为字符串
writer.close()
def read_and_decode(filename): # 读入dog_train.tfrecords
filename_queue = tf.train.string_input_producer([filename]) # 生成一个queue队列
reader = tf.tfrecordreader()
_, serialized_example = reader.read(filename_queue) # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
features = {
'label' : tf.fixedlenfeature([], tf.int64),
'img_raw' : tf.fixedlenfeature([], tf.string),
}) # 将image数据和label取出来
img = tf.decode_raw(features[ 'img_raw' ], tf.uint8)
img = tf.reshape(img, [ 224 , 224 , 3 ]) # reshape为128*128的3通道图片
img = tf.cast(img, tf.float32) * ( 1. / 255 ) - 0.5 # 在流中抛出img张量
label = tf.cast(features[ 'label' ], tf.int32) # 在流中抛出label张量
return img, label
filename_queue = tf.train.string_input_producer([ "flower_train.tfrecords" ]) # 读入流中
reader = tf.tfrecordreader()
_, serialized_example = reader.read(filename_queue) # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
features = {
'label' : tf.fixedlenfeature([], tf.int64),
'img_raw' : tf.fixedlenfeature([], tf.string),
}) # 取出包含image和label的feature对象
image = tf.decode_raw(features[ 'img_raw' ], tf.uint8)
image = tf.reshape(image, [ 224 , 224 , 3 ])
label = tf.cast(features[ 'label' ], tf.int32)
label = tf.one_hot(label, 17 , 1 , 0 )
with tf.session() as sess: # 开始一个会话
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord = tf.train.coordinator()
threads = tf.train.start_queue_runners(coord = coord)
for i in range ( 100 ):
example, l = sess.run([image, label]) # 在会话中取出image和label
img = image.fromarray(example, 'rgb' ) # 这里image是之前提到的
img.save(cwd + str (i) + '_' 'label_' + str (l) + '.jpg' ) # 存下图片
print (example, l)
coord.request_stop()
coord.join(threads)
|
本人也是刚刚学习深度学习,能力有限,不足之处请见谅,欢迎大牛一起讨论,共同进步!
以上这篇对python制作自己的数据集实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/zhangjunp3/article/details/79627824