MXNet官网案例分析--Train MLP on MNIST

时间:2021-01-18 19:53:14

本文是MXNet的官网案例: Train MLP on MNIST. MXNet所有的模块如下图所示:

MXNet官网案例分析--Train MLP on MNIST

第一步: 准备数据

从下面程序可以看出,MXNet里面的数据是一个4维NDArray.

import mxnet as mx

# mxnet.io.MXDataIter, shape=(128,1,28,28)
train = mx.io.MNISTIter(
image = '/home/zhaopace/MXNet/mxnet/example/adversary/data/train-images-idx3-ubyte',
label = '/home/zhaopace/MXNet/mxnet/example/adversary/data/train-labels-idx1-ubyte',
batch_size = 128,
data_shape = (784, )
)
# mxnet.io.MXDataIter, shape=(128,1,28,28)
val = mx.io.MNISTIter(
image = '/home/zhaopace/MXNet/mxnet/example/adversary/data/t10k-images-idx3-ubyte',
label = '/home/zhaopace/MXNet/mxnet/example/adversary/data/t10k-labels-idx1-ubyte',
batch_size = 128,
data_shape = (784, )
)

Second: 符号式编程, 生成一个两层的MLP

# Declare a two-layer MLP
data = mx.symbol.Variable('data') # data layer
fc1 = mx.symbol.FullyConnected(data=data, num_hidden=128) # full connected layer 1
act1 = mx.symbol.Activation(data=fc1, act_type="relu") # activation layer(relu activation function)
fc2 = mx.symbol.FullyConnected(data=act1, num_hidden=64)
act2 = mx.symbol.Activation(data=fc2, act_type="relu")
fc3 = mx.symbol.FullyConnected(data=act2, num_hidden=10)
mlp = mx.symbol.SoftmaxOutput(data=fc3, name="softmax") # Softmax layer

一个CNN网络最基本的几层:

输入层: mx.symbol.Variable()

激活层: mx.symbol.Activation()

Batch正则化: mx.symbol.BatchNorm()

Dropout: mx.symbol.Dropout()

全连接层: mx.symbol.FullyConnected()

池化层: mx.symbol.Pooling()

卷积层: mx.symbol.Convolution()

Softmax输出: mx.symbol.SoftmaxOutput()

LRN: mx.symbol.LRN()

......

mx.symbol.FullyConnected(*args, **kwargs)

功能: 对input作矩阵乘法, 并且加上一个偏置. 将shape为(batch_size, input_dim)的input变成(batch_size, num_hidden)的输出;

输入参数:

  • data:  Symbol类型, 输入数据;
  • weight:  Symbol类型, 权重矩阵;
  • bias:  Symbol类型, 偏置参数;
  • num_hidden: int型, 必要参数, 隐层节点的数目;
  • no_bias: 布尔型, 可选参数, defalut=False, 表示是否不要偏置参数
  • name:  字符串类型, 可选参数, 计算结果symbol的名称;

输出参数:

  • 输出是一个Symbol: the result symbol

Last: 训练以及测试

# Type: mxnet.model.FeedForward
# Train a model on the data
model = mx.model.FeedForward(
symbol = mlp,
num_epoch = 20,
learning_rate = .1
)
model.fit(X = train, eval_data = val) # Predict
model.predict(X = train)

class mxnet.model.FeedForward(sklearn.base.BaseEstimator)

输入参数:

  • symbol: Symbol类型, 网络的symbol结构配置;
  • ctx:
  • num_epoch: int型, 可选参数,是一个训练参数, 训练的迭代次数;
  • epoch_size: 一次epoch使用的batches数目, 默认情况下为(num_train_examples / batch_size)
  • optimizer:q
  • initializer:
  • numpy_batch_size:
  • ......

MXNet官网案例分析--Train MLP on MNIST

图2 mxnet.model函数列表