k-均值算法的java实现

时间:2022-01-30 23:22:17

 

  1. import java.io.BufferedReader;   
  2. import java.io.FileNotFoundException;   
  3. import java.io.FileReader;   
  4. import java.io.IOException;   
  5.   
  6. public class KAverage {   
  7.     private int sampleCount = 0;   
  8.     private int dimensionCount = 0;   
  9.     private int centerCount = 0;   
  10.     private double[][] sampleValues;   
  11.     private double[][] centers;   
  12.     private double[][] tmpCenters;   
  13.     private String dataFile = "";   
  14.   
  15.     /**  
  16.      * 通过构造器传人数据文件  
  17.      */  
  18.     public KAverage(String dataFile) throws NumberInvalieException {   
  19.         this.dataFile = dataFile;   
  20.     }   
  21.   
  22.     /**  
  23.      * 第一行为s;d;c含义分别为样例的数目,每个样例特征的维数,聚类中心个数 文件格式为d[,d]...;d[,d]... 如:1,2;2,3;1,5  
  24.      * 每一维之间用,隔开,每个样例间用;隔开。结尾没有';' 可以有多行  
  25.      */  
  26.   
  27.     private int initData(String fileName) {   
  28.         String line;   
  29.         String samplesValue[];   
  30.         String dimensionsValue[] = new String[dimensionCount];   
  31.         BufferedReader in;   
  32.         try {   
  33.             in = new BufferedReader(new FileReader(fileName));   
  34.         } catch (FileNotFoundException e) {   
  35.             e.printStackTrace();   
  36.             return -1;   
  37.         }   
  38.         /*  
  39.          * 预处理样本,允许后面几维为0时,不写入文件  
  40.          */  
  41.         for (int i = 0; i < sampleCount; i++) {   
  42.             for (int j = 0; j < dimensionCount; j++) {   
  43.                 sampleValues[i][j] = 0;   
  44.             }   
  45.         }   
  46.   
  47.         int i = 0;   
  48.         double tmpValue = 0.0;   
  49.         try {   
  50.             line = in.readLine();   
  51.             String params[] = line.split(";");   
  52.             if (params.length != 3) {// 必须为3个参数,否则错误   
  53.                 return -1;   
  54.             }   
  55.             /**  
  56.              * 获取参数  
  57.              */  
  58.             this.sampleCount = Integer.parseInt(params[0]);   
  59.             this.dimensionCount = Integer.parseInt(params[1]);   
  60.             this.centerCount = Integer.parseInt(params[2]);   
  61.             if (sampleCount <= 0 || dimensionCount <= 0 || centerCount <= 0) {   
  62.                 throw new NumberInvalieException("input number <= 0.");   
  63.             }   
  64.             if (sampleCount < centerCount) {   
  65.                 throw new NumberInvalieException(   
  66.                         "sample number < center number");   
  67.             }   
  68.   
  69.             sampleValues = new double[sampleCount][dimensionCount + 1];   
  70.             centers = new double[centerCount][dimensionCount];   
  71.             tmpCenters = new double[centerCount][dimensionCount];   
  72.   
  73.             while ((line = in.readLine()) != null) {   
  74.                 samplesValue = line.split(";");   
  75.                 for (int j = 0; j < samplesValue.length; j++) {   
  76.                     dimensionsValue = samplesValue[j].split(",");   
  77.                     for (int k = 0; k < dimensionsValue.length; k++) {   
  78.                         tmpValue = Double.parseDouble(dimensionsValue[k]);   
  79.                         sampleValues[i][k] = tmpValue;   
  80.                     }   
  81.                     i++;   
  82.                 }   
  83.             }   
  84.   
  85.         } catch (IOException e) {   
  86.             e.printStackTrace();   
  87.             return -2;   
  88.         } catch (Exception e) {   
  89.             e.printStackTrace();   
  90.             return -3;   
  91.         }   
  92.         return 1;   
  93.     }   
  94.   
  95.     /**  
  96.      * 返回样本中第s1个和第s2个间的欧式距离  
  97.      */  
  98.     private double getDistance(int s1, int s2) throws NumberInvalieException {   
  99.         double distance = 0.0;   
  100.         if (s1 < 0 || s1 >= sampleCount || s2 < 0 || s2 >= sampleCount) {   
  101.             throw new NumberInvalieException("number out of bound.");   
  102.         }   
  103.         for (int i = 0; i < dimensionCount; i++) {   
  104.             distance += (sampleValues[s1][i] - sampleValues[s2][i])   
  105.                     * (sampleValues[s1][i] - sampleValues[s2][i]);   
  106.         }   
  107.   
  108.         return distance;   
  109.     }   
  110.   
  111.     /**  
  112.      * 返回给定两个向量间的欧式距离  
  113.      */  
  114.     private double getDistance(double s1[], double s2[]) {   
  115.         double distance = 0.0;   
  116.         for (int i = 0; i < dimensionCount; i++) {   
  117.             distance += (s1[i] - s2[i]) * (s1[i] - s2[i]);   
  118.         }   
  119.         return distance;   
  120.     }   
  121.   
  122.     /**  
  123.      * 更新样本中第s个样本的最近中心  
  124.      */  
  125.     private int getNearestCenter(int s) {   
  126.         int center = 0;   
  127.         double minDistance = Double.MAX_VALUE;   
  128.         double distance = 0.0;   
  129.         for (int i = 0; i < centerCount; i++) {   
  130.             distance = getDistance(sampleValues[s], centers[i]);   
  131.             if (distance < minDistance) {   
  132.                 minDistance = distance;   
  133.                 center = i;   
  134.             }   
  135.         }   
  136.         sampleValues[s][dimensionCount] = center;   
  137.         return center;   
  138.     }   
  139.   
  140.     /**  
  141.      * 更新所有中心  
  142.      */  
  143.     private void updateCenters() {   
  144.         double center[] = new double[dimensionCount];   
  145.         for (int i = 0; i < dimensionCount; i++) {   
  146.             center[i] = 0;   
  147.         }   
  148.         int count = 0;   
  149.         for (int i = 0; i < centerCount; i++) {   
  150.             count = 0;   
  151.             for (int j = 0; j < sampleCount; j++) {   
  152.                 if (sampleValues[j][dimensionCount] == i) {   
  153.                     count++;   
  154.                     for (int k = 0; k < dimensionCount; k++) {   
  155.                         center[k] += sampleValues[j][k];   
  156.                     }   
  157.                 }   
  158.             }   
  159.             for (int j = 0; j < dimensionCount; j++) {   
  160.                 centers[i][j] = center[j] / count;   
  161.             }   
  162.         }   
  163.     }   
  164.   
  165.     /**  
  166.      * 判断算法是否终止  
  167.      */  
  168.     private boolean toBeContinued() {   
  169.         for (int i = 0; i < centerCount; i++) {   
  170.             for (int j = 0; j < dimensionCount; j++) {   
  171.                 if (tmpCenters[i][j] != centers[i][j]) {   
  172.                     return true;   
  173.                 }   
  174.             }   
  175.         }   
  176.         return false;   
  177.     }   
  178.   
  179.     /**  
  180.      * 关键方法,调用其他方法,处理数据  
  181.      */  
  182.     public void doCaculate() {   
  183.         initData(dataFile);   
  184.   
  185.         for (int i = 0; i < centerCount; i++) {   
  186.             for (int j = 0; j < dimensionCount; j++) {   
  187.                 centers[i][j] = sampleValues[i][j];   
  188.             }   
  189.         }   
  190.         for (int i = 0; i < centerCount; i++) {   
  191.             for (int j = 0; j < dimensionCount; j++) {   
  192.                 tmpCenters[i][j] = 0;   
  193.             }   
  194.         }   
  195.   
  196.         while (toBeContinued()) {   
  197.             for (int i = 0; i < sampleCount; i++) {   
  198.                 getNearestCenter(i);   
  199.             }   
  200.             for (int i = 0; i < centerCount; i++) {   
  201.                 for (int j = 0; j < dimensionCount; j++) {   
  202.                     tmpCenters[i][j] = centers[i][j];   
  203.                 }   
  204.             }   
  205.             updateCenters();   
  206.             System.out   
  207.                     .println("******************************************************");   
  208.             showResultData();   
  209.         }   
  210.     }   
  211.   
  212.     /*  
  213.      * 显示数据  
  214.      */  
  215.     private void showSampleData() {   
  216.         for (int i = 0; i < sampleCount; i++) {   
  217.             for (int j = 0; j < dimensionCount; j++) {   
  218.                 if (j == 0) {   
  219.                     System.out.print(sampleValues[i][j]);   
  220.                 } else {   
  221.                     System.out.print("," + sampleValues[i][j]);   
  222.                 }   
  223.             }   
  224.             System.out.println();   
  225.         }   
  226.     }   
  227.   
  228.     /*  
  229.      * 分组显示结果  
  230.      */  
  231.     private void showResultData() {   
  232.         for (int i = 0; i < centerCount; i++) {   
  233.             System.out.println("第" + (i + 1) + "个分组内容为:");   
  234.             for (int j = 0; j < sampleCount; j++) {   
  235.                 if (sampleValues[j][dimensionCount] == i) {   
  236.                     for (int k = 0; k <= dimensionCount; k++) {   
  237.                         if (k == 0) {   
  238.                             System.out.print(sampleValues[j][k]);   
  239.                         } else {   
  240.                             System.out.print("," + sampleValues[j][k]);   
  241.                         }   
  242.                     }   
  243.                     System.out.println();   
  244.                 }   
  245.             }   
  246.         }   
  247.     }   
  248.   
  249.     public static void main(String[] args) {   
  250.         /*  
  251.          *也可以通过命令行得到参数  
  252.          */  
  253.         String fileName = "D://eclipsejava//K-Average//src//sample.txt";   
  254.         if(args.length > 0){   
  255.             fileName = args[0];   
  256.         }   
  257.            
  258.         try {   
  259.             KAverage ka = new KAverage(fileName);   
  260.             ka.doCaculate();   
  261.             System.out   
  262.                     .println("***************************<<result>>**************************");   
  263.             ka.showResultData();   
  264.         } catch (Exception e) {   
  265.             e.printStackTrace();   
  266.         }   
  267.     }   
  268. }  

 

Java代码 k-均值算法的java实现
  1.   
  2. /*  
  3.  * 根据自己的需要定义一些异常,使得系统性更强  
  4.  */  
  5. public class NumberInvalieException extends Exception {   
  6.     private String cause;   
  7.        
  8.     public NumberInvalieException(String cause){   
  9.         if(cause == null || "".equals(cause)){   
  10.             this.cause = "unknow";   
  11.         }else{   
  12.             this.cause = cause;   
  13.         }   
  14.     }   
  15.     @Override  
  16.     public String toString() {   
  17.         return "Number Invalie!Cause by " + cause;   
  18.     }   
  19. }  



测试数据
20;2;4
0,0;1,0;0,1;1,1;2,1;1,2;2,2;3,2;6,6;7,6
8,6;6,7;7,7;8,7;9,7;7,8;8,8;9,8;8,9;9,9
测试结果
***************************<<result>>**************************
第1个分组内容为:
0.0,0.0,0.0
1.0,0.0,0.0
0.0,1.0,0.0
1.0,1.0,0.0
2.0,1.0,0.0
1.0,2.0,0.0
2.0,2.0,0.0
3.0,2.0,0.0
第2个分组内容为:
6.0,6.0,1.0
7.0,6.0,1.0
8.0,6.0,1.0
6.0,7.0,1.0
7.0,7.0,1.0
8.0,7.0,1.0
9.0,7.0,1.0
7.0,8.0,1.0
8.0,8.0,1.0
9.0,8.0,1.0
8.0,9.0,1.0
9.0,9.0,1.0