机器学习笔记(二十四)——Tensorflow 2(数据增强与迁移学习)

时间:2024-03-17 22:46:26

本博客仅用于个人学习,不用于传播教学,主要是记自己能够看得懂的笔记(

学习知识来自:【吴恩达团队Tensorflow2.0实践系列课程第一课】TensorFlow2.0中基于TensorFlow2.0的人工智能、机器学习和深度学习简介及基础编程_哔哩哔哩_bilibili

数据增强:以图片为例,将图片进行放大、旋转、平移、翻转等操作,就是一种数据增强的方式,因为测试数据中的目标不一定在图片*,也不一定是固定的姿态。

迁移学习:自己run辣么大的数据肯定很累啦,所以要用踩在巨人的肩膀上发展。迁移学习就是借用别人做过的模型,在其基础上训练自己的数据,以达到更好的效果。

数据可见:机器学习笔记(二十三)——Tensorflow 2(可视化) - Lcy的瞎bb - 博客园 (cnblogs.com)

别人的模型:https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5

代码:

import os
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator as idg
from tensorflow.keras.preprocessing import image
from tensorflow.keras import Model
from tensorflow.keras import layers
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.optimizers import RMSprop
filepath=\'E:/Python_Files/tensorflow/transfer-learning\'
local=filepath+\'/tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5\'
#这是网上的大佬自己训练了百万个数据得到的神经网络,这里我们就用ta的来迁移学习
files=[]
for root,dirs,files in os.walk(filepath+\'/tmp/test\'):
    used_up_variable=0

class myCallback(tf.keras.callbacks.Callback): #Callback类的继承类
    def on_epoch_end(self,epoch,logs={}): #重写on_epoch_end函数
        if logs.get(\'val_acc\')>0.95:
            print(\'\nReached 95% accuracy so canceling training.\')
            self.model.stop_training=True #达到条件,停止训练

callback=myCallback()

p_model=InceptionV3( #keras内部自带的神经网络结构
    input_shape=(150,150,3),
    include_top=False, #该神经网络顶部是Dense层,可以不需要
    weights=None #内置的权值都不要
)
p_model.load_weights(local) #将大佬的模型导入到这个结构中
for layer in p_model.layers:
    layer.trainable=False #固定,不能训练
#print(p_model.summary())

last_layer=p_model.get_layer(\'mixed7\') #选中其中一层,一般是最底层,此处不是
last_output=last_layer.output #得到输出

x=layers.Flatten()(last_output)
x=layers.Dense(1024,activation=\'relu\')(x)
x=layers.Dropout(0.2)(x)
x=layers.Dense(1,activation=\'sigmoid\')(x) #搞几个连接层
model=Model(p_model.input,x) #以此建立新模型
model.compile(
    optimizer=RMSprop(lr=0.0001),
    loss=\'binary_crossentropy\',
    metrics=[\'acc\']
)

datagen=idg( #此处为增强数据的各种方法
    rescale=1./255, #归一化
    rotation_range=40, #随机旋转-40~40度
    width_shift_range=0.2, #随机改变宽度20%
    height_shift_range=0.2, #随机改变高度20%
    shear_range=0.2, #随机倾斜裁剪20%
    zoom_range=0.2, #随机放大20%
    horizontal_flip=True, #随机翻转
    fill_mode=\'nearest\' #图像改变后,需要将像素重新填充,用最近像素的特征来填充
) #数据增强的generator
traingen=datagen.flow_from_directory( #训练数据集,并用文件夹名作为标签分类
    filepath+\'/tmp/train\', #数据集所在地址
    target_size=(150,150), #自动生成300*300的图片
    batch_size=10, #这个不能太大,不然会超内存,所以我这个程序运行得贼慢
    class_mode=\'binary\' #二分类模式
)
valigen=datagen.flow_from_directory( #验证数据集
    filepath+\'/tmp/validation\',
    target_size=(150,150),
    batch_size=10,
    class_mode=\'binary\'
)

model.fit( #训练模型
    traingen, #训练数据集generator
    steps_per_epoch=200, #注意前面的batch_size,这两个乘起来要大于等于数据集个数
    epochs=100,
    validation_data=valigen, #验证数据集generator
    validation_steps=100, #这个与验证数据集的batch_size也是一样,乘起来大于等于验证数据集个数
    callbacks=[callback]
)

for file in files:
    pat=filepath+\'/tmp/test/\'+file
#    img=cv2.imdecode(np.fromfile(pat,dtype=np.uint8),-1)
    img=image.load_img(pat,target_size=(150,150)) #导入图像
    x=image.img_to_array(img) #变成array类
    imgs=np.expand_dims(x,axis=0) #增加一个维度

#    imgs=np.vstack([imgs])
    imgs=imgs/255.0 #归一化,应该可以不用
    classes=model.predict(imgs,batch_size=10)
    print(classes[0]) #输出预测值
    if classes[0]>0.5:
        print(file+\' is a dog.\')
    else:
        print(file+\' is a cat.\')

得到结果:

Reached 95% accuracy so canceling training.
200/200 [==============================] - 40s 201ms/step - loss: 0.0703 - acc: 0.9760 - val_loss: 0.2258 - val_acc: 0.9520
[7.234914e-18]
cat.jpg is a cat.
[5.679721e-19]
cat1.jpg is a cat.
[0.00087641]
cat2.jpg is a cat.
[0.00128891]
cat3.jpg is a cat.
[0.999015]
cat4.jpg is a dog.
[7.817044e-12]
cat5.jpg is a cat.
[1.2193369e-11]
cat6.jpg is a cat.
[4.3203903e-09]
cat_cut.png is a cat.
[1.8515493e-10]
cat_from_traindata.jpg is a cat.
[1.]
dog.jpg is a dog.
[0.9999778]
dog1.jpg is a dog.
[1.]
dog2.jpg is a dog.
[1.]
dog3.jpg is a dog.
[1.]
dog4.jpg is a dog.
[0.9444226]
dog5.jpg is a dog.
[1.]
dog6.jpg is a dog.
[1.]
dog7.jpg is a dog.
[1.]
dog_from_traindata.jpg is a dog.

结果还不错啊。

参考博客:

Keras ImageDataGenerator参数_jacke121的专栏-CSDN博客_imagedatagenerator