基于Java实现的一层简单人工神经网络算法示例

时间:2021-10-25 01:10:44

本文实例讲述了基于java实现的一层简单人工神经网络算法。分享给大家供大家参考,具体如下:

先来看看笔者绘制的算法图:

基于Java实现的一层简单人工神经网络算法示例

基于Java实现的一层简单人工神经网络算法示例

2、数据类

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import java.util.arrays;
public class data {
  double[] vector;
  int dimention;
  int type;
  public double[] getvector() {
    return vector;
  }
  public void setvector(double[] vector) {
    this.vector = vector;
  }
  public int getdimention() {
    return dimention;
  }
  public void setdimention(int dimention) {
    this.dimention = dimention;
  }
  public int gettype() {
    return type;
  }
  public void settype(int type) {
    this.type = type;
  }
  public data(double[] vector, int dimention, int type) {
    super();
    this.vector = vector;
    this.dimention = dimention;
    this.type = type;
  }
  public data() {
  }
  @override
  public string tostring() {
    return "data [vector=" + arrays.tostring(vector) + ", dimention=" + dimention + ", type=" + type + "]";
  }
}

3、简单人工神经网络

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
package cn.edu.hbut.chenjie;
import java.util.arraylist;
import java.util.list;
import java.util.random;
import org.jfree.chart.chartfactory;
import org.jfree.chart.chartframe;
import org.jfree.chart.jfreechart;
import org.jfree.data.xy.defaultxydataset;
import org.jfree.ui.refineryutilities;
public class ann2 {
  private double eta;//学习率
  private int n_iter;//权重向量w[]训练次数
  private list<data> exercise;//训练数据集
  private double w0 = 0;//阈值
  private double x0 = 1;//固定值
  private double[] weights;//权重向量,其长度为训练数据维度+1,在本例中数据为2维,故长度为3
  private int testsum = 0;//测试数据总数
  private int error = 0;//错误次数
  defaultxydataset xydataset = new defaultxydataset();
  /**
   * 向图表中增加同类型的数据
   * @param type 类型
   * @param a 所有数据的第一个分量
   * @param b 所有数据的第二个分量
   */
  public void add(string type,double[] a,double[] b)
  {
    double[][] data = new double[2][a.length];
    for(int i=0;i<a.length;i++)
    {
      data[0][i] = a[i];
      data[1][i] = b[i];
    }
    xydataset.addseries(type, data);
  }
  /**
   * 画图
   */
  public void draw()
  {
    jfreechart jfreechart = chartfactory.createscatterplot("exercise", "x1", "x2", xydataset);
    chartframe frame = new chartframe("训练数据", jfreechart);
    frame.pack();
    refineryutilities.centerframeonscreen(frame);
    frame.setvisible(true);
  }
  public static void main(string[] args)
  {
    ann2 ann2 = new ann2(0.001,100);//构造人工神经网络
    list<data> exercise = new arraylist<data>();//构造训练集
    //人工模拟1000条训练数据 ,分界线为x2=x1+0.5
    for(int i=0;i<1000000;i++)
    {
      random rd = new random();
      double x1 = rd.nextdouble();//随机产生一个分量
      double x2 = rd.nextdouble();//随机产生另一个分量
      double[] da = {x1,x2};//产生数据向量
      data d = new data(da, 2, x2 > x1+0.5 ? 1 : -1);//构造数据
      exercise.add(d);//将训练数据加入训练集
    }
    int sum1 = 0;//记录类型1的训练记录数
    int sum2 = 0;//记录类型-1的训练记录数
    for(int i = 0; i < exercise.size(); i++)
    {
      if(exercise.get(i).gettype()==1)
        sum1++;
      else if(exercise.get(i).gettype()==-1)
        sum2++;
    }
    double[] x1 = new double[sum1];
    double[] y1 = new double[sum1];
    double[] x2 = new double[sum2];
    double[] y2 = new double[sum2];
    int index1 = 0;
    int index2 = 0;
    for(int i = 0; i < exercise.size(); i++)
    {
      if(exercise.get(i).gettype()==1)
      {
        x1[index1] = exercise.get(i).vector[0];
        y1[index1++] = exercise.get(i).vector[1];
      }
      else if(exercise.get(i).gettype()==-1)
      {
        x2[index2] = exercise.get(i).vector[0];
        y2[index2++] = exercise.get(i).vector[1];
      }
    }
    ann2.add("1", x1, y1);
    ann2.add("-1", x2, y2);
    ann2.draw();
    ann2.input(exercise);//将训练集输入人工神经网络
    ann2.fit();//训练
    ann2.showweigths();//显示权重向量
    //人工生成一千条测试数据
    for(int i=0;i<10000;i++)
    {
      random rd = new random();
      double x1_ = rd.nextdouble();
      double x2_ = rd.nextdouble();
      double[] da = {x1_,x2_};
      data test = new data(da, 2, x2_ > x1_+0.5 ? 1 : -1);
      ann2.predict(test);//测试
    }
    system.out.println("总共测试" + ann2.testsum + "条数据,有" + ann2.error + "条错误,错误率:" + ann2.error * 1.0 /ann2.testsum * 100 + "%");
  }
  /**
   *
   * @param eta 学习率
   * @param n_iter 权重分量学习次数
   */
  public ann2(double eta, int n_iter) {
    this.eta = eta;
    this.n_iter = n_iter;
  }
  /**
   * 输入训练集到人工神经网络
   * @param exercise
   */
  private void input(list<data> exercise) {
    this.exercise = exercise;//保存训练集
    weights = new double[exercise.get(0).dimention + 1];//初始化权重向量,其长度为训练数据维度+1
    weights[0] = w0;//权重向量第一个分量为w0
    for(int i = 1; i < weights.length; i++)
      weights[i] = 0;//其余分量初始化为0
  }
  private void fit() {
    for(int i = 0; i < n_iter; i++)//权重分量调整n_iter次
    {
      for(int j = 0; j < exercise.size(); j++)//对于训练集中的每条数据进行训练
      {
        int real_result = exercise.get(j).type;//y
        int calculate_result = calculateresult(exercise.get(j));//y'
        double delta0 = eta * (real_result - calculate_result);//计算阈值更新
        w0 += delta0;//阈值更新
        weights[0] = w0;//更新w[0]
        for(int k = 0; k < exercise.get(j).getdimention(); k++)//更新权重向量其它分量
        {
          double delta = eta * (real_result - calculate_result) * exercise.get(j).vector[k];
          //δw=η*(y-y')*x
          weights[k+1] += delta;
          //w=w+δw
        }
      }
    }
  }
  private int calculateresult(data data) {
    double z = w0 * x0;
    for(int i = 0; i < data.dimention; i++)
      z += data.vector[i] * weights[i+1];
    //z=w0x0+w1x1+...+wmxm
    //激活函数
    if(z>=0)
      return 1;
    else
      return -1;
  }
  private void showweigths()
  {
    for(double w : weights)
      system.out.println(w);
  }
  private void predict(data data) {
    int type = calculateresult(data);
    if(type == data.gettype())
    {
      //system.out.println("预测正确");
    }
    else
    {
      //system.out.println("预测错误");
      error ++;
    }
    testsum ++;
  }
}

运行结果:

?
1
2
3
4
-0.22000000000000017
-0.4416843982815453
0.442444202054685
总共测试10000条数据,有17条错误,错误率:0.16999999999999998%

基于Java实现的一层简单人工神经网络算法示例

希望本文所述对大家java程序设计有所帮助。

原文链接:http://blog.csdn.net/csj941227/article/details/73325695