Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.3

时间:2022-09-12 19:00:49

Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.3

http://blog.csdn.net/sunbow0

第二章Deep Belief Network (深度信念网络)

3实例

3.1 测试数据

按照上例数据,或者新建图片识别数据。

3.2 DBN实例

//****************2(读取固定样本:来源于经典优化算法测试函数Sphere Model***********//

    //2 读取样本数据

    Logger.getRootLogger.setLevel(Level.WARN)

    valdata_path ="/user/huangmeiling/deeplearn/data1"

    valexamples =sc.textFile(data_path).cache()

    valtrain_d1 =examples.map { line =>

      valf1 = line.split("\t")

      valf =f1.map(f =>f.toDouble)

      valid =f(0)

      valy = Array(f(1))

      valx =f.slice(2,f.length)

      (id, new BDM(1,y.length,y),new BDM(1,x.length,x))

    }

    valtrain_d =train_d1.map(f => (f._2, f._3))

    valopts = Array(100.0,20.0,0.0) 

    //3 设置训练参数,建立DBN模型

    valDBNmodel =new DBN().

      setSize(Array(5, 7)).

      setLayer(2).

      setMomentum(0.1).

      setAlpha(1.0).

      DBNtrain(train_d, opts) 

    //4 DBN模型转化为NN模型

    valmynn =DBNmodel.dbnunfoldtonn(1)

    valnnopts = Array(100.0,50.0,0.0)

    valnumExamples =train_d.count()

    println(s"numExamples = $numExamples.")

    println(mynn._2)

    for (i <-0 tomynn._1.length -1) {

      print(mynn._1(i) +"\t")

    }

    println()

    println("mynn_W1")

    valtmpw1 =mynn._3(0)

    for (i <-0 totmpw1.rows -1) {

      for (j <-0 totmpw1.cols -1) {

        print(tmpw1(i,j) +"\t")

      }

      println()

    }

    valNNmodel =new NeuralNet().

      setSize(mynn._1).

      setLayer(mynn._2).

      setActivation_function("sigm").

      setOutput_function("sigm").

      setInitW(mynn._3).

      NNtrain(train_d, nnopts) 

    //5 NN模型测试

    valNNforecast =NNmodel.predict(train_d)

    valNNerror =NNmodel.Loss(NNforecast)

    println(s"NNerror = $NNerror.")

    valprintf1 =NNforecast.map(f => (f.label.data(0), f.predict_label.data(0))).take(200)

    println("预测结果——实际值:预测值:误差")

    for (i <-0 untilprintf1.length)

      println(printf1(i)._1 +"\t" +printf1(i)._2 +"\t" + (printf1(i)._2 -printf1(i)._1)) 

转载请注明出处:

http://blog.csdn.net/sunbow0