【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用

时间:2023-02-23 17:35:40

一、前述

VGG16是由16层神经网络构成的经典模型,包括多层卷积,多层全连接层,一般我们改写的时候卷积层基本不动,全连接层从后面几层依次向前改写,因为先改参数较小的。

二、具体

1、因为本文中代码需要依赖OpenCV,所以第一步先安装OpenCV

【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用

因为VGG要求输入244*244,而数据集是28*28的,所以需要通过OpenCV在代码里去改变。

2、把模型下载后离线放入用户的管理目录下面,这样训练的时候就不需要从网上再下载了

【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用

3、我们保留的是除了全连接的所有层。

4、选择数据生成器,在真正使用的时候才会生成数据,加载到内存,前面yield只是做了一个标记

【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用

 代码:

# 使用迁移学习的思想,以VGG16作为模板搭建模型,训练识别手写字体
# 引入VGG16模块
from keras.applications.vgg16 import VGG16 # 其次加载其他模块
from keras.layers import Input
from keras.layers import Flatten
from keras.layers import Dense
from keras.layers import Dropout
from keras.models import Model
from keras.optimizers import SGD # 加载字体库作为训练样本
from keras.datasets import mnist # 加载OpenCV(在命令行中窗口中输入pip install opencv-python),这里为了后期对图像的处理,
# 大家使用pip install C:\Users\28542\Downloads\opencv_python-3.4.1+contrib-cp35-cp35m-win_amd64.whl
# 比如尺寸变化和Channel变化。这些变化是为了使图像满足VGG16所需要的输入格式
import cv2
import h5py as h5py
import numpy as np # 建立一个模型,其类型是Keras的Model类对象,我们构建的模型会将VGG16顶层(全连接层)去掉,只保留其余的网络
# 结构。这里用include_top = False表明我们迁移除顶层以外的其余网络结构到自己的模型中
# VGG模型对于输入图像数据要求高宽至少为48个像素点,由于硬件配置限制,我们选用48个像素点而不是原来
# VGG16所采用的224个像素点。即使这样仍然需要24GB以上的内存,或者使用数据生成器
model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(48, 48, 3))#输入进来的数据是48*48 3通道
#选择imagnet,会选择当年大赛的初始参数
#include_top=False 去掉最后3层的全连接层看源码可知
for layer in model_vgg.layers:
layer.trainable = False#别去调整之前的卷积层的参数
model = Flatten(name='flatten')(model_vgg.output)#去掉全连接层,前面都是卷积层
model = Dense(4096, activation='relu', name='fc1')(model)
model = Dense(4096, activation='relu', name='fc2')(model)
model = Dropout(0.5)(model)
model = Dense(10, activation='softmax')(model)#model就是最后的y
model_vgg_mnist = Model(inputs=model_vgg.input, outputs=model, name='vgg16')
#把model_vgg.input X传进来
#把model Y传进来 就可以训练模型了 # 打印模型结构,包括所需要的参数
model_vgg_mnist.summary() #以下是原版的模型结构 224*224
model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
for layer in model_vgg.layers:
layer.trainable = False#别去调整之前的卷积层的参数
model = Flatten()(model_vgg.output)
model = Dense(4096, activation='relu', name='fc1')(model)
model = Dense(4096, activation='relu', name='fc2')(model)
model = Dropout(0.5)(model)
model = Dense(10, activation='softmax', name='prediction')(model)
model_vgg_mnist_pretrain = Model(model_vgg.input, model, name='vgg16_pretrain') model_vgg_mnist_pretrain.summary() # 新的模型不需要训练原有卷积结构里面的1471万个参数,但是注意参数还是来自于最后输出层前的两个
# 全连接层,一共有1.2亿个参数需要训练
sgd = SGD(lr=0.05, decay=1e-5)#lr 学习率 decay 梯度的逐渐减小 每迭代一次梯度就下降 0.05*(1-(10的-5))这样来变
#随着越来越下降 学习率越来越小 步子越小
model_vgg_mnist.compile(loss='categorical_crossentropy',
optimizer=sgd, metrics=['accuracy']) # 因为VGG16对网络输入层需要接受3通道的数据的要求,我们用OpenCV把图像从32*32变成224*224,把黑白图像转成RGB图像
# 并把训练数据转化成张量形式,供keras输入
(X_train, y_train), (X_test, y_test) = mnist.load_data("../test_data_home")
X_train, y_train = X_train[:1000], y_train[:1000]#训练集1000条
X_test, y_test = X_test[:100], y_test[:100]#测试集100条
X_train = [cv2.cvtColor(cv2.resize(i, (48, 48)), cv2.COLOR_GRAY2RGB)
for i in X_train]#变成彩色的
#np.concatenate拼接到一起把
X_train = np.concatenate([arr[np.newaxis] for arr in X_train]).astype('float32') X_test = [cv2.cvtColor(cv2.resize(i, (48, 48)), cv2.COLOR_GRAY2RGB)
for i in X_test]
X_test = np.concatenate([arr[np.newaxis] for arr in X_test]).astype('float32') print(X_train.shape)
print(X_test.shape) X_train = X_train / 255
X_test = X_test / 255 def tran_y(y):
y_ohe = np.zeros(10)
y_ohe[y] = 1
return y_ohe y_train_ohe = np.array([tran_y(y_train[i]) for i in range(len(y_train))])
y_test_ohe = np.array([tran_y(y_test[i]) for i in range(len(y_test))]) model_vgg_mnist.fit(X_train, y_train_ohe, validation_data=(X_test, y_test_ohe),
epochs=100, batch_size=50)

 结果:

【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用

 自定义的网络层:

【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用

【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用的更多相关文章

  1. 利用神经网络算法的C#手写数字识别(二)

    利用神经网络算法的C#手写数字识别(二)   本篇主要内容: 让项目编译通过,并能打开图片进行识别.   1. 从上一篇<利用神经网络算法的C#手写数字识别>中的源码地址下载源码与资源, ...

  2. 利用神经网络算法的C#手写数字识别&lpar;一&rpar;

    利用神经网络算法的C#手写数字识别 转发来自云加社区,用于学习机器学习与神经网络 欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwri ...

  3. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

  4. NN:利用深度学习之神经网络实现手写数字识别&lpar;数据集50000张图片&rpar;—Jason niu

    import mnist_loader import network training_data, validation_data, test_data = mnist_loader.load_dat ...

  5. 手写数字识别——利用keras高层API快速搭建并优化网络模型

    在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...

  6. mnist手写数字识别——深度学习入门项目(tensorflow&plus;keras&plus;Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  7. keras框架的MLP手写数字识别MNIST,梳理?

    keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...

  8. keras—多层感知器MLP—MNIST手写数字识别

    一.手写数字识别 现在就来说说如何使用神经网络实现手写数字识别. 在这里我使用mind manager工具绘制了要实现手写数字识别需要的模块以及模块的功能:  其中隐含层节点数量(即神经细胞数量)计算 ...

  9. keras和tensorflow搭建DNN、CNN、RNN手写数字识别

    MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...

随机推荐

  1. 【leetcode❤python】 438&period; Find All Anagrams in a String

    class Solution(object):    def findAnagrams(self, s, p):        """        :type s: s ...

  2. Android中的Context详解

    前言:本文是我读<Android内核剖析>第7章 后形成的读书笔记 ,在此向欲了解Android框架的书籍推荐此书. 大家好,  今天给大家介绍下我们在应用开发中最熟悉而陌生的朋友---- ...

  3. 仿酷狗音乐播放器开发日志二十六 duilib在标题栏弹出菜单的方法

    转载请说明原出处,谢谢~~ 上篇日志说明了怎么让自定义控件响应右键消息.之后我给主窗体的标题栏增加右键响应,观察原酷狗后可以发现,在整个标题栏都是可以响应右键并弹出菜单的.应该的效果如下: 本以为像上 ...

  4. eclipse上 安装php插件

    首先在安装之前需要有eclipse   以及SDK环境已经搭建好 eclipse开发工具下载路径: http://dl.oschina.net/soft/eclipse java sdk下载路径: h ...

  5. Search a 2D Matrix【python】

    class Solution: # @param matrix, a list of lists of integers # @param target, an integer # @return a ...

  6. SELinux(Security-Enhanced Linux)

    http://blog.csdn.net/myarrow/article/details/9839377 Security-Enhanced Linux(SELinux)的历史 一个小历史将有助于帮助 ...

  7. Material Design学习-----FloatingActionButton

    FloatingActionButton是悬浮操作按钮,它继承自imageview,所以说它具备有imageview所有的方法和属性.与其他按钮不同的是,FloatingActionButton默认就 ...

  8. Ext 创建workspace package

    Ext 创建workspace package Package ExtJs Project 1. 创建工作区间文件目录 md wpt 2. 进入目录 cd wpt 3. 创建 创建工作区间 sench ...

  9. Unity-Shader-动态阴影&lpar;上&rpar; 投影的矩阵变换过程

    [旧博客转移 - 2017年1月20日 01:20 ] 前面的话 最近很长时间没写博文了,一是太忙 ( lan ) 了,二是这段时间又领悟了一些东西,脑子里很混乱,不知道从何写起.但感觉不能再拖延下去 ...

  10. yield next和yield&ast; next的区别

    yield next和yield* next之间到底有什么区别?为什么需要yield* next?经常会有人提出这个问题.虽然我们在代码中会尽量避免使用yield* next以减少新用户的疑惑,但还是 ...