框架:keras
数据集:CIFAR10
模型:vgg16
注:vgg16模型的输入图像尺寸至少为 48*48
思路:去掉vgg16的顶层(全连接分类层),保留其余的卷积网络结构与在ImageNet上训练好的权重,并将这些权重冻结;然后在其后添加新的全连接层,用CIFAR10只训练这些新增的层。
1.模型结构
2.具体代码以及注释
①训练代码
#-*- coding: utf-8 -*-
#迁移学习,vgg16+cifar10
from keras.applications.vgg16 import VGG16
from keras.layers import Dense, Flatten, Dropout
from keras.models import Model
from keras.optimizers import SGD
from keras.datasets import cifar10
import cv2 #加载opencv,为了后期能够修改图像尺寸
from keras import datasets
import h5py as h5py
import numpy as np
# Side length (pixels) of the square RGB images fed to the network.
ishape = 64

# include_top=False drops VGG16's fully connected classifier head and keeps
# the rest of the network together with its ImageNet-pretrained weights.
model_vgg = VGG16(include_top=False, weights='imagenet',
                  input_shape=(ishape, ishape, 3))

# Freeze every pretrained layer so that only the layers added below are
# updated during training.
for layer in model_vgg.layers:
    layer.trainable = False

# New classifier head: two 4096-unit ReLU layers, dropout, and a 10-way
# softmax output (one unit per CIFAR-10 class).
model = Flatten()(model_vgg.output)
model = Dense(4096, activation='relu', name='fc1')(model)
model = Dense(4096, activation='relu', name='fc2')(model)
model = Dropout(0.5)(model)
model = Dense(10, activation='softmax', name='prediction')(model)

model_vgg_cifar10_pretrain = Model(inputs=model_vgg.input, outputs=model,
                                   name='vgg16_pretrain')
model_vgg_cifar10_pretrain.summary()

# Plain SGD with a small learning-rate decay; the loss expects one-hot labels.
sgd = SGD(lr=0.05, decay=1e-5)
model_vgg_cifar10_pretrain.compile(optimizer=sgd,
                                   loss='categorical_crossentropy',
                                   metrics=['accuracy'])
# Load CIFAR-10 and resize every image to the ishape x ishape input size
# the network was built with.
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = [cv2.resize(img, (ishape, ishape)) for img in X_train]
X_test = [cv2.resize(img, (ishape, ishape)) for img in X_test]

# Stack the resized images back into single float32 arrays with the image
# index as the first axis.
X_train = np.stack(X_train).astype('float32')
X_test = np.stack(X_test).astype('float32')

# Preprocessing: show one sample for a quick sanity check, then scale the
# pixel values by 1/255.
print(X_train[0].shape)
print(y_train[0])
X_train = X_train / 255
X_test = X_test / 255
# One-hot (dummy) encoding of class labels.
def train_y(y, num_classes=10):
    """Return the one-hot vector for class index ``y``.

    Args:
        y: Class index (an int, or a one-element array/list holding it).
        num_classes: Length of the returned vector; defaults to 10.

    Returns:
        A 1-D float array of length ``num_classes`` that is 1 at position
        ``y`` and 0 everywhere else.
    """
    y_one = np.zeros(num_classes)
    y_one[y] = 1
    return y_one
# Build the one-hot label matrices (one row per sample) for both splits.
y_train_one = np.array([train_y(label) for label in y_train])
y_test_one = np.array([train_y(label) for label in y_test])
# Train the model, evaluating on the test split after every epoch.
validation_set = (X_test, y_test_one)
model_vgg_cifar10_pretrain.fit(
    X_train,
    y_train_one,
    batch_size=128,
    epochs=50,
    validation_data=validation_set
)

# Save the trained model to disk.
model_vgg_cifar10_pretrain.save('cifar10.h5')
②识别代码
#-*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as processimage
from keras.models import load_model
from scipy import misc
import scipy
# Load the trained model from 'cifar10.h5' once, at module level; the
# prediction code below calls model.predict on it.
model = load_model('cifar10.h5') # load the saved model
class MainPredictImg(object):
    """Classify an image file with the module-level ``model``."""

    # Side length (pixels) that input images are resized to.
    IMG_SIZE = 64
    # Class names, indexed by the position of the model's output unit.
    LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer',
              'dog', 'frog', 'horse', 'ship', 'truck']

    def __init__(self):
        pass

    def pre(self, filename):
        """Print the probability of every class and return the best label.

        Args:
            filename: Path of the image file to classify.

        Returns:
            The name of the class with the highest predicted probability.
        """
        size = self.IMG_SIZE
        pixels = np.array(processimage.imread(filename))
        # Resize an image of any size to the size the network expects.
        # NOTE(review): scipy.misc.imresize was removed in SciPy 1.3, so
        # this call needs an older SciPy release.
        pixels = scipy.misc.imresize(pixels, size=(size, size))
        # Scale to [0, 1]: the training script divides its images by 255,
        # so raw 0-255 values would not match what the network was trained on.
        batch = pixels.reshape(-1, size, size, 3).astype('float32') / 255
        prediction = model.predict(batch)
        scores = prediction[0]
        for label, score in zip(self.LABELS, scores):
            print(label)
            # '.30%' prints the value as a percentage with 30 decimal places.
            print('Percent:{:.30%}'.format(score))
        return self.LABELS[scores.argmax()]
def main():
    """Classify the sample image and print the predicted class name."""
    predictor = MainPredictImg()
    res = predictor.pre('airplant.jpg')  # path of the image to classify
    # Single-argument print() produces the same output on Python 2 and 3.
    print('your picture is :--> {}'.format(res))


if __name__ == '__main__':
    main()
3.识别结果:
参考书籍:
《Keras快速上手:基于Python的深度学习实战(谢梁等)》