TensorFlow 框架--TFLearn使用

时间:2022-01-29 13:53:27
得益于TensorFlow 社区的繁荣,诞生出许多高质量的元框架(metaframework),如Keras、

TFLearn、TensorLayer 等。使用元框架能够大大减少编写TensorFlow 代码的工作量,方便开发
者很速搭建网络模型,并且使代码简单、可读性强。
TFLearn 是一个建立在TensorFlow 顶部的模块化的深度学习框架,它为TensorFlow 提供更
高级的API,以便于很速实验,同时保持完全透明和兼容。
1.TFLearn功能简介
TFLearn的功能包括:
(1)易于使用和理解高级API用于实施深层神经网络,教程和示例
(2)通过高度模块化的内置神经网络层,正则化器,优化器,度量,快速原型
(3)tensorflow安全透明。所有函数都建立在张量上,可以独立于tflearn的使用
(4)强大的帮助函数来训练任何tensorflow图,支持多个输入,输出和优化器
(5)简单和美丽的图形可视化,有关权重,渐变,激活等详细信息…
(6)轻松使用多个CPU/GPU的设备布局
下面介绍一个小案例,感受代码的简洁:
2 .加载数据*
这里用的是牛津大学的鲜花数据集(http://www.robots.ox.ac.uk/~vgg/data/flowers/17/)(Flower Dataset)。这个数据集提供了17 个类别的鲜
花数据,每个类别80 张图片,并且图片有大量的姿态和光的变化。
注意,在代码的开始需要导入用到的与卷积、池化、规范化相关的类,方法如下:
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.normalization import local_response_normalization
from tflearn.layers.estimator import regression
import tflearn.datasets.oxflower17 as oxflower17
X, Y = oxflower17.load_data(one_hot=True, resize_pics=(227, 227))

3、 构建网络模型
构建AlexNet 网络模型时,直接使用TFLearn 中的卷积、池化、规范化、全连接、dropout
函数来构建即可。方法如下:
构建AlexNet 网络
network = input_data(shape=[None, 227, 227, 3])
network = conv_2d(network, 96, 11, strides=4, activation=’relu’)
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = conv_2d(network, 256, 5, activation=’relu’)
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = conv_2d(network, 384, 3, activation=’relu’)
network = conv_2d(network, 384, 3, activation=’relu’)
network = conv_2d(network, 256, 3, activation=’relu’)
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = fully_connected(network, 4096, activation=’tanh’)
network = dropout(network, 0.5)
network = fully_connected(network, 4096, activation=’tanh’)
network = dropout(network, 0.5)
network = fully_connected(network, 17, activation=’softmax’)
network = regression(network, optimizer=’momentum’,
loss=’categorical_crossentropy’,
learning_rate=0.001) # 回归操作,同时规定网络所使用的学习率、损失函数和优化器
“`
4、 训练模型
构建完模型之后,就可以训练模型了。这里我们加了一步,就是假设有训练好或训练到一
半的AlexNet 模型的检查点文件,直接载入,方法如下:
model = tflearn.DNN(network, checkpoint_path=’model_alexnet’,
max_checkpoints=1, tensorboard_verbose=2)
model.fit(X, Y, n_epoch=1000, validation_set=0.1, shuffle=True,
show_metric=True, batch_size=64, snapshot_step=200,
snapshot_epoch=False, run_id=’alexnet_oxflowers17’)