K均值聚类算法的Java版实现代码示例

时间:2022-06-06 03:02:17

1.简介

K均值聚类算法是先随机选取K个对象作为初始的聚类中心。然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。聚类中心以及分配给它们的对象就代表一个聚类。一旦全部对象都被分配了,每个聚类的聚类中心会根据聚类中现有的对象被重新计算。这个过程将不断重复直到满足某个终止条件。终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。

2.什么是聚类

聚类是一个将数据集中在某些方面相似的数据成员进行分类组织的过程,聚类就是一种发现这种内在结构的技术,聚类技术经常被称为无监督学习。

3.什么是k均值聚类

k均值聚类是最著名的划分聚类算法,由于简洁和效率使得他成为所有聚类算法中最广泛使用的。给定一个数据点集合和需要的聚类数目k,k由用户指定,k均值算法根据某个距离函数反复把数据分入k个聚类中。

4.实现

Java代码如下:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
package org.algorithm;
import java.util.ArrayList;
import java.util.Random;
/**
 * K均值聚类算法
 */
public class Kmeans {
    private int k;
    // 分成多少簇
    private int m;
    // 迭代次数
    private int dataSetLength;
    // 数据集元素个数,即数据集的长度
    private ArrayList<float[]> dataSet;
    // 数据集链表
    private ArrayList<float[]> center;
    // 中心链表
    private ArrayList<ArrayList<float[]>> cluster;
    // 簇
    private ArrayList<float> jc;
    // 误差平方和,k越接近dataSetLength,误差越小
    private Random random;
    /**
   * 设置需分组的原始数据集
   *
   * @param dataSet
   */
    public void setDataSet(ArrayList<float[]> dataSet) {
        this.dataSet = dataSet;
    }
    /**
   * 获取结果分组
   *
   * @return 结果集
   */
    public ArrayList<ArrayList<float[]>> getCluster() {
        return cluster;
    }
    /**
   * 构造函数,传入需要分成的簇数量
   *
   * @param k
   *      簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度
   */
    public Kmeans(int k) {
        if (k <= 0) {
            k = 1;
        }
        this.k = k;
    }
    /**
   * 初始化
   */
    private void init() {
        m = 0;
        random = new Random();
        if (dataSet == null || dataSet.size() == 0) {
            initDataSet();
        }
        dataSetLength = dataSet.size();
        if (k > dataSetLength) {
            k = dataSetLength;
        }
        center = initCenters();
        cluster = initCluster();
        jc = new ArrayList<float>();
    }
    /**
   * 如果调用者未初始化数据集,则采用内部测试数据集
   */
    private void initDataSet() {
        dataSet = new ArrayList<float[]>();
        // 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0
        float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },
                { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },
                { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };
        for (int i = 0; i < dataSetArray.length; i++) {
            dataSet.add(dataSetArray[i]);
        }
    }
    /**
   * 初始化中心数据链表,分成多少簇就有多少个中心点
   *
   * @return 中心点集
   */
    private ArrayList<float[]> initCenters() {
        ArrayList<float[]> center = new ArrayList<float[]>();
        int[] randoms = new int[k];
        Boolean flag;
        int temp = random.nextint(dataSetLength);
        randoms[0] = temp;
        for (int i = 1; i < k; i++) {
            flag = true;
            while (flag) {
                temp = random.nextint(dataSetLength);
                int j = 0;
                // 不清楚for循环导致j无法加1
                // for(j=0;j<i;++j)
                // {
                // if(temp==randoms[j]);
                // {
                // break;
                // }
                // }
                while (j < i) {
                    if (temp == randoms[j]) {
                        break;
                    }
                    j++;
                }
                if (j == i) {
                    flag = false;
                }
            }
            randoms[i] = temp;
        }
        // 测试随机数生成情况
        // for(int i=0;i<k;i++)
        // {
        // System.out.println("test1:randoms["+i+"]="+randoms[i]);
        // }
        // System.out.println();
        for (int i = 0; i < k; i++) {
            center.add(dataSet.get(randoms[i]));
            // 生成初始化中心链表
        }
        return center;
    }
    /**
   * 初始化簇集合
   *
   * @return 一个分为k簇的空数据的簇集合
   */
    private ArrayList<ArrayList<float[]>> initCluster() {
        ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();
        for (int i = 0; i < k; i++) {
            cluster.add(new ArrayList<float[]>());
        }
        return cluster;
    }
    /**
   * 计算两个点之间的距离
   *
   * @param element
   *      点1
   * @param center
   *      点2
   * @return 距离
   */
    private float distance(float[] element, float[] center) {
        float distance = 0.0f;
        float x = element[0] - center[0];
        float y = element[1] - center[1];
        float z = x * x + y * y;
        distance = (float) Math.sqrt(z);
        return distance;
    }
    /**
   * 获取距离集合中最小距离的位置
   *
   * @param distance
   *      距离数组
   * @return 最小距离在距离数组中的位置
   */
    private int minDistance(float[] distance) {
        float minDistance = distance[0];
        int minLocation = 0;
        for (int i = 1; i < distance.length; i++) {
            if (distance[i] < minDistance) {
                minDistance = distance[i];
                minLocation = i;
            } else if (distance[i] == minDistance) // 如果相等,随机返回一个位置
            {
                if (random.nextint(10) < 5) {
                    minLocation = i;
                }
            }
        }
        return minLocation;
    }
    /**
   * 核心,将当前元素放到最小距离中心相关的簇中
   */
    private void clusterSet() {
        float[] distance = new float[k];
        for (int i = 0; i < dataSetLength; i++) {
            for (int j = 0; j < k; j++) {
                distance[j] = distance(dataSet.get(i), center.get(j));
                // System.out.println("test2:"+"dataSet["+i+"],center["+j+"],distance="+distance[j]);
            }
            int minLocation = minDistance(distance);
            // System.out.println("test3:"+"dataSet["+i+"],minLocation="+minLocation);
            // System.out.println();
            cluster.get(minLocation).add(dataSet.get(i));
            // 核心,将当前元素放到最小距离中心相关的簇中
        }
    }
    /**
   * 求两点误差平方的方法
   *
   * @param element
   *      点1
   * @param center
   *      点2
   * @return 误差平方
   */
    private float errorSquare(float[] element, float[] center) {
        float x = element[0] - center[0];
        float y = element[1] - center[1];
        float errSquare = x * x + y * y;
        return errSquare;
    }
    /**
   * 计算误差平方和准则函数方法
   */
    private void countRule() {
        float jcF = 0;
        for (int i = 0; i < cluster.size(); i++) {
            for (int j = 0; j < cluster.get(i).size(); j++) {
                jcF += errorSquare(cluster.get(i).get(j), center.get(i));
            }
        }
        jc.add(jcF);
    }
    /**
   * 设置新的簇中心方法
   */
    private void setNewCenter() {
        for (int i = 0; i < k; i++) {
            int n = cluster.get(i).size();
            if (n != 0) {
                float[] newCenter = { 0, 0 };
                for (int j = 0; j < n; j++) {
                    newCenter[0] += cluster.get(i).get(j)[0];
                    newCenter[1] += cluster.get(i).get(j)[1];
                }
                // 设置一个平均值
                newCenter[0] = newCenter[0] / n;
                newCenter[1] = newCenter[1] / n;
                center.set(i, newCenter);
            }
        }
    }
    /**
   * 打印数据,测试用
   *
   * @param dataArray
   *      数据集
   * @param dataArrayName
   *      数据集名称
   */
    public void printDataArray(ArrayList<float[]> dataArray,
          String dataArrayName) {
        for (int i = 0; i < dataArray.size(); i++) {
            System.out.println("print:" + dataArrayName + "[" + i + "]={"
                      + dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
        }
        System.out.println("===================================");
    }
    /**
   * Kmeans算法核心过程方法
   */
    private void kmeans() {
        init();
        // printDataArray(dataSet,"initDataSet");
        // printDataArray(center,"initCenter");
        // 循环分组,直到误差不变为止
        while (true) {
            clusterSet();
            // for(int i=0;i<cluster.size();i++)
            // {
            // printDataArray(cluster.get(i),"cluster["+i+"]");
            // }
            countRule();
            // System.out.println("count:"+"jc["+m+"]="+jc.get(m));
            // System.out.println();
            // 误差不变了,分组完成
            if (m != 0) {
                if (jc.get(m) - jc.get(m - 1) == 0) {
                    break;
                }
            }
            setNewCenter();
            // printDataArray(center,"newCenter");
            m++;
            cluster.clear();
            cluster = initCluster();
        }
        // System.out.println("note:the times of repeat:m="+m);//输出迭代次数
    }
    /**
   * 执行算法
   */
    public void execute() {
        long startTime = System.currentTimeMillis();
        System.out.println("kmeans begins");
        kmeans();
        long endTime = System.currentTimeMillis();
        System.out.println("kmeans running time=" + (endTime - startTime)
                + "ms");
        System.out.println("kmeans ends");
        System.out.println();
    }
}

5.说明:

具体代码是从网上找的,根据自己的理解加了注释和进行部分修改,若注释有误还望指正

6.测试

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
package org.test;
import java.util.ArrayList;
import org.algorithm.Kmeans;
public class KmeansTest {
    public static void main(String[] args)
      {
        //初始化一个Kmean对象,将k置为10
        Kmeans k=new Kmeans(10);
        ArrayList<float[]> dataSet=new ArrayList<float[]>();
        dataSet.add(new float[]{1,2});
        dataSet.add(new float[]{3,3});
        dataSet.add(new float[]{3,4});
        dataSet.add(new float[]{5,6});
        dataSet.add(new float[]{8,9});
        dataSet.add(new float[]{4,5});
        dataSet.add(new float[]{6,4});
        dataSet.add(new float[]{3,9});
        dataSet.add(new float[]{5,9});
        dataSet.add(new float[]{4,2});
        dataSet.add(new float[]{1,9});
        dataSet.add(new float[]{7,8});
        //设置原始数据集
        k.setDataSet(dataSet);
        //执行算法
        k.execute();
        //得到聚类结果
        ArrayList<ArrayList<float[]>> cluster=k.getCluster();
        //查看结果
        for (int i=0;i<cluster.size();i++)
            {
            k.printDataArray(cluster.get(i), "cluster["+i+"]");
        }
    }
}

总结:测试代码已经通过。并对聚类的结果进行了查看,结果基本上符合要求。至于有没有更精确的算法有待发现。具体的实践还有待挖掘

总结

以上就是本文关于K均值聚类算法的Java版实现代码示例的全部内容,希望对大家有所帮助。感兴趣的朋友可以继续参阅本站其他相关专题。如有不足之处,欢迎留言指出。感谢朋友们对本站的支持!

原文链接:http://blog.csdn.net/cyxlzzs/article/details/7416491