1. 安装
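pip install graphviz
pip install pydot
pip install pydot-ng  # 版本兼容需要
# 注意：pydot 还依赖系统中安装的 Graphviz 可执行程序（如 dot），仅用 pip 安装 Python 包是不够的。
# 测试一下（在 Python 中执行，无报错即安装成功；Keras 2 中该模块改为 keras.utils.vis_utils，函数名为 plot_model）：
from keras.utils.visualize_util import plot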
2. 使用:
#!/usr/bin/env python
# coding=utf-8
"""
End-to-end CAPTCHA recognition with a Keras CNN -- simple, direct, brute force.

Training for 100 epochs reaches about 95% accuracy, but the model overfits
easily and generalizes poorly; apart from adding more training data, no
better remedy has been found yet.

__author__: jkmiao
__email__: miao1202@126.com
__date__: 2017-02-08
"""
import os
import random

import numpy as np
from PIL import Image
from keras.models import Sequential
from keras.models import model_from_json
from keras.layers import Dense, Dropout, Flatten, Activation, LSTM, Reshape
from keras.layers import Convolution2D, MaxPooling2D
from keras.callbacks import ModelCheckpoint
from keras.utils.visualize_util import plot
from sklearn.model_selection import train_test_split

from util import CharacterTable

# Width of each encoded label row (filled from CharacterTable.encode) and of
# the freshly built model's final Dense layer; the two must agree.
N_OUTPUT = 216


def load_data(path='img/clearNoise/', threshold=180):
    """Load binarised CAPTCHA images and their text labels from *path*.

    Every file in *path* whose name ends with ``jpg`` is read, converted to
    greyscale and thresholded: pixels brighter than *threshold* become 1,
    all others 0. The label is the part of the file name before the first
    underscore, lower-cased (e.g. ``AbCd_001.jpg`` -> ``abcd``).
    Files are visited in random order.

    Args:
        path: Directory containing the images.
        threshold: Grey level (0-255) above which a pixel counts as 1.

    Returns:
        A tuple ``(data, labels)``. ``data`` is a numpy array stacking one
        ``(height, width, 1)`` array per image (assumes all images share
        the same size -- TODO confirm). ``labels`` is a list of strings in
        the same order.
    """
    fnames = [os.path.join(path, fname) for fname in os.listdir(path)
              if fname.endswith('jpg')]
    random.shuffle(fnames)
    data, label = [], []
    for fname in fnames:
        # basename (not split('/')) so OS-specific separators in *path* work.
        img_label = os.path.basename(fname).split('_')[0]
        # Close the file handle deterministically instead of relying on GC.
        with Image.open(fname) as img:
            img_m = np.array(img.convert('L'))
        img_m = 1 * (img_m > threshold)
        # Add a trailing channel axis: (H, W) -> (H, W, 1).
        data.append(img_m.reshape((img_m.shape[0], img_m.shape[1], 1)))
        label.append(img_label.lower())
    return np.array(data), label


ctable = CharacterTable()
data, label = load_data()
label_onehot = np.zeros((len(label), N_OUTPUT))
for i, lb in enumerate(label):
    label_onehot[i, :] = ctable.encode(lb)
print(data.shape)
print(label_onehot.shape)

x_train, x_test, y_train, y_test = train_test_split(data, label_onehot, test_size=0.1)

# True: build a fresh network; False: resume from the previously saved model.
DEBUG = False

# Build the model
if DEBUG:
    model = Sequential()
    model.add(Convolution2D(32, 5, 5, border_mode='valid',
                            input_shape=(60, 200, 1), name='conv1'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(32, 3, 3, name='conv2'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    # Disabled experiment: reshape + LSTM after the conv stack.
    # model.add(Reshape((20, 60)))
    # model.add(LSTM(32))
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dense(N_OUTPUT))
    # NOTE(review): one softmax over all N_OUTPUT units forces the whole
    # vector to sum to 1; if CharacterTable.encode concatenates one one-hot
    # block per character, a per-character softmax would fit better -- confirm.
    model.add(Activation('softmax'))
else:
    with open('model/ba_cnn_model2.json') as fr:
        model = model_from_json(fr.read())
    model.load_weights('model/ba_cnn_model2.h5')

# Compile. The former class_mode='categorical' was removed: it is not a
# Model.compile argument (a Keras 0.x leftover) and newer Keras rejects it.
model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])
model.summary()
# Plot the architecture to model.png (needs pydot + Graphviz installed).
plot(model, to_file='model.png', show_shapes=True)

# Train, checkpointing the model whenever validation loss improves.
check_pointer = ModelCheckpoint('./model/train_len_size1.h5', monitor='val_loss',
                                verbose=1, save_best_only=True)
model.fit(x_train, y_train, batch_size=32, nb_epoch=5,
          validation_split=0.1, callbacks=[check_pointer])

# Persist architecture and weights to the paths the DEBUG=False branch loads.
# NOTE(review): these are the weights after the LAST epoch, not the best
# checkpoint written to train_len_size1.h5 above -- confirm this is intended.
json_string = model.to_json()
with open('./model/ba_cnn_model2.json', 'w') as fw:
    fw.write(json_string)
model.save_weights('./model/ba_cnn_model2.h5')

# Evaluate on the held-out split: a sample counts as correct only when the
# decoded prediction equals the decoded ground-truth label.
y_pred = model.predict(x_test, verbose=1)
cnt = 0
for i, (pred_vec, true_vec) in enumerate(zip(y_pred, y_test)):
    guess = ctable.decode(pred_vec)
    correct = ctable.decode(true_vec)
    if guess == correct:
        cnt += 1
    if i % 10 == 0:
        # Every 10th sample, show the prediction next to the ground truth.
        # Single-argument print() gives the same output on Python 2 and 3.
        print('%s %d' % ('--' * 10, i))
        print('y_pred %s' % (guess,))
        print('y_test %s' % (correct,))
# Whole-label accuracy on the test split.
print(cnt / float(len(y_pred)))