DeepLearnToolbox DBN源码解析

时间:2022-12-14 13:04:55

这几天看了下DeepLearnToolbox的源码,在此记录一下自己对DBN代码的理解。


test_example_DBN.m: 测试代码

function test_example_DBN
% TEST_EXAMPLE_DBN Pretrain a 2-hidden-layer DBN on MNIST, unfold it into a
% feed-forward NN, fine-tune the NN and check that the test error is < 10%.
%
% NOTE: comments use "%" - MATLAB does not accept C-style "//" comments.
load ../data/mnist_40000_10000;
addpath('../DBN');
addpath('../NN');
addpath('../util');
% Scale pixel intensities into [0, 1]; rbmtrain asserts this range.
train_x = double(train_x) / 255;
test_x  = double(test_x)  / 255;
train_y = double(train_y);
test_y  = double(test_y);

% Fix the (legacy) random generator state so the run is reproducible.
rand('state',0)

% --- Train the DBN ---
% Hidden-layer sizes only; dbnsetup prepends the input dimension, giving the
% stack v1 = raw pixels, h1/v2 with 100 units, h2/v3 with 200 units.
dbn.sizes = [100 200];
opts.numepochs =   3;
opts.batchsize = 100;
% Momentum: fraction of the previous update direction that is added to the
% current one, to speed up learning (0 disables it).
opts.momentum  =   0;
% Learning rate used by rbmtrain.
opts.alpha     =   1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

% Unfold the DBN into an NN with an extra 10-unit output layer.
nn = dbnunfoldtonn(dbn, 10);
nn.activation_function = 'sigm';

% --- Train the NN ---
% Fine-tune the NN starting from the DBN-pretrained weights.
opts.numepochs =  3;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);

assert(er < 0.10, 'Too big error');
end



dbnsetup.m:建立DBN网络

function dbn = dbnsetup(dbn, x, opts)
% DBNSETUP Allocate the stacked RBMs of a DBN.
%   dbn = dbnsetup(dbn, x, opts) prepends the input dimension (the number of
%   columns of x) to dbn.sizes and creates one RBM per pair of adjacent
%   layers. All weights, biases and momentum buffers start at zero.
%
%   Inputs:
%     dbn  - struct whose field "sizes" lists the hidden-layer sizes,
%            e.g. [100 200].
%     x    - training data, one sample per row.
%     opts - struct providing alpha (learning rate) and momentum, which are
%            copied into every RBM.
%
%   Output:
%     dbn  - dbn.sizes becomes [size(x, 2), old sizes] (e.g. [784 100 200]
%            for 784-pixel images), and for each layer u dbn.rbm{u} holds:
%              W, vW : sizes(u+1) x sizes(u) weights and their velocity
%              b, vb : sizes(u) x 1 visible-layer bias and its velocity
%              c, vc : sizes(u+1) x 1 hidden-layer bias and its velocity
    n = size(x, 2);
    dbn.sizes = [n, dbn.sizes];

    % Initialise W, b and c (plus their momentum buffers) for every RBM.
    for u = 1 : numel(dbn.sizes) - 1
        dbn.rbm{u}.alpha    = opts.alpha;
        dbn.rbm{u}.momentum = opts.momentum;

        dbn.rbm{u}.W  = zeros(dbn.sizes(u + 1), dbn.sizes(u));
        dbn.rbm{u}.vW = zeros(dbn.sizes(u + 1), dbn.sizes(u));

        % Visible-layer bias.
        dbn.rbm{u}.b  = zeros(dbn.sizes(u), 1);
        dbn.rbm{u}.vb = zeros(dbn.sizes(u), 1);

        % Hidden-layer bias.
        dbn.rbm{u}.c  = zeros(dbn.sizes(u + 1), 1);
        dbn.rbm{u}.vc = zeros(dbn.sizes(u + 1), 1);
    end

end

dbntrain.m:训练DBN

function dbn = dbntrain(dbn, x, opts)
% DBNTRAIN Greedy layer-wise pretraining of a DBN.
%   dbn = dbntrain(dbn, x, opts) trains dbn.rbm{1} on x with rbmtrain. It then
%   trains each following RBM, bottom to top, on the output of the RBM below,
%   which is obtained with rbmup.
%
%   Inputs:
%     dbn  - DBN struct as produced by dbnsetup.
%     x    - training data, one sample per row.
%     opts - training options, passed unchanged to rbmtrain.
    n = numel(dbn.rbm);

    dbn.rbm{1} = rbmtrain(dbn.rbm{1}, x, opts);
    for i = 2 : n
        % Feed the data through the already-trained layer below.
        % NOTE(review): rbmup is defined outside this listing; presumably it
        % returns sigm(x * W' + c), i.e. sigm(W*x + c) per sample - confirm.
        x = rbmup(dbn.rbm{i - 1}, x);
        dbn.rbm{i} = rbmtrain(dbn.rbm{i}, x, opts);
    end

end


rbmtrain.m:训练RBM

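采用对比散度(Contrastive Divergence, CD)算法进行训练,这是Hinton于2002年提出的RBM快速学习算法。
算法描述见《Learning Deep Architectures for AI》中的Algorithm 1。单步CD(CD-1)的主要流程如下(按代码整理;$\sigma$ 为sigmoid函数,$\epsilon$ 为学习率):

1. 由数据采样隐层:$h_1 \sim P(h=1 \mid v_1) = \sigma(W v_1 + c)$
2. 重构可视层:$v_2 \sim P(v=1 \mid h_1) = \sigma(W^\top h_1 + b)$
3. 再次计算隐层概率(不采样):$Q(h_2=1 \mid v_2) = \sigma(W v_2 + c)$
4. 更新参数:
   $W \leftarrow W + \epsilon\,\big(h_1 v_1^\top - Q(h_2=1 \mid v_2)\, v_2^\top\big)$
   $b \leftarrow b + \epsilon\,(v_1 - v_2)$
   $c \leftarrow c + \epsilon\,\big(h_1 - Q(h_2=1 \mid v_2)\big)$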

function rbm = rbmtrain(rbm, x, opts)
% RBMTRAIN Train one RBM with single-step contrastive divergence (CD-1).
%   rbm = rbmtrain(rbm, x, opts) runs opts.numepochs passes over x in
%   shuffled mini-batches of opts.batchsize rows. It updates rbm.W, rbm.b and
%   rbm.c through the momentum buffers rbm.vW, rbm.vb and rbm.vc.
%
%   Inputs:
%     rbm  - struct with W (hidden x visible), b (visible bias, column),
%            c (hidden bias, column), vW/vb/vc (velocities), alpha
%            (learning rate) and momentum, as created by dbnsetup.
%     x    - floating-point data in [0, 1], one sample per row.
%     opts - struct providing numepochs and batchsize. size(x, 1) must be a
%            multiple of batchsize.
%
%   Prints the average reconstruction error after each epoch.
%   The update follows Algorithm 1 of Bengio, "Learning Deep Architectures
%   for AI".
    assert(isfloat(x), 'x must be a float');
    assert(all(x(:)>=0) && all(x(:)<=1), 'all data in x must be in [0:1]');
    m = size(x, 1);
    numbatches = m / opts.batchsize;

    assert(rem(numbatches, 1) == 0, 'numbatches not integer');

    for i = 1 : opts.numepochs
        % Visit the samples in a new random order every epoch.
        kk = randperm(m);
        err = 0;
        for l = 1 : numbatches
            batch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize), :);

            % One Gibbs step v1 -> h1 -> v2 -> h2.
            % h1 and v2 are sampled (sigmrnd); h2 is kept as probabilities
            % (sigm), i.e. sigm(W*v2 + c).
            % NOTE(review): sigm/sigmrnd are defined outside this listing.
            v1 = batch;
            h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');
            v2 = sigmrnd(repmat(rbm.b', opts.batchsize, 1) + h1 * rbm.W);
            h2 = sigm(repmat(rbm.c', opts.batchsize, 1) + v2 * rbm.W');

            % Positive (data) and negative (reconstruction) statistics.
            c1 = h1' * v1;
            c2 = h2' * v2;

            % Momentum update: combine the previous update direction with the
            % current batch-averaged gradient to speed up learning.
            % sum(., 1) forces column sums, so the bias gradients stay column
            % vectors even when batchsize == 1. A plain sum() would collapse
            % the single row to a scalar.
            rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)        / opts.batchsize;
            rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2, 1)' / opts.batchsize;
            rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2, 1)' / opts.batchsize;

            rbm.W = rbm.W + rbm.vW;
            rbm.b = rbm.b + rbm.vb;
            rbm.c = rbm.c + rbm.vc;

            % Accumulate the per-sample squared reconstruction error.
            err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;
        end

        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)  '. Average reconstruction error is: ' num2str(err / numbatches)]);

    end
end


dbnunfoldtonn.m:利用DBN的参数去初始化NN,然后用NN进行微调(即 nn = nntrain(nn, train_x, train_y, opts);)。

function nn = dbnunfoldtonn(dbn, outputsize)
%DBNUNFOLDTONN Unfolds a DBN to a NN
%   dbnunfoldtonn(dbn, outputsize ) returns the unfolded dbn with a final layer of size outputsize added.
%   If outputsize is omitted, the NN has exactly the DBN's layers.
%
%   Each RBM initialises the matching NN layer: nn.W{i} is set to
%   [dbn.rbm{i}.c dbn.rbm{i}.W], i.e. the hidden bias as the first column
%   followed by the weight matrix. The added output layer (if any) is not
%   touched and keeps the initialisation produced by nnsetup.
    % Named layersizes (not "size") so the builtin size() is not shadowed.
    if(exist('outputsize','var'))
        layersizes = [dbn.sizes outputsize];
    else
        layersizes = [dbn.sizes];
    end
    nn = nnsetup(layersizes);
    for i = 1 : numel(dbn.rbm)
        % Use the pretrained W and c of each RBM as the NN layer parameters.
        nn.W{i} = [dbn.rbm{i}.c dbn.rbm{i}.W];
    end
end


CNN源码解析:
- http://blog.csdn.net/zouxy09/article/details/9993743
- http://blog.csdn.net/dark_scope/article/details/9495505






Reference:
(1) Learning Deep Architectures for AI (Yoshua Bengio, 2009)
(2) A Practical Guide to Training Restricted Boltzmann Machines (Geoffrey Hinton, 2010)