数据挖掘--kmeans聚类算法mapreduce实现 代码

时间:2021-03-05 23:41:29
分类: 数据挖掘算法 | 标签:  数据挖掘    kmeans    mapreduce  
2012-11-14 13:36阅读( 4148) 评论(1)
==================cluster.txt===========================
A    2    2
B    2    4
C    4    2
D    4    4
E    6    6
F    6    8
G    8    6
H    8    8
==================cluster.center.conf===========================
K1    3    2
K2    6    2
====================================================================================
package com.mahout.cluster;

//二维坐标的点
public class DmRecord {
    private String name;
    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    private double xpodouble;
    private double ypodouble;
    
    public DmRecord(){
        
    }
    
    public DmRecord(String name,double x,double y){
        this.name = name;
        this.xpodouble = x;
        this.ypodouble = y;
    }

    public double getXpoint() {
        return xpodouble;
    }

    public void setXpoint(double xpodouble) {
        this.xpodouble = xpodouble;
    }

    public double getYpoint() {
        return ypodouble;
    }

    public void setYpoint(double ypodouble) {
        this.ypodouble = ypodouble;
    }
    
    public  double distance(DmRecord record){
        return Math.sqrt(Math.pow(this.xpodouble-record.xpodouble, 2)+Math.pow(this.ypodouble-record.ypodouble, 2));
    }
}
==============================================================================
package com.mahout.cluster;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.io.IOUtils;


public class DmRecordParser {
    private Map<String,DmRecord> urlMap = new HashMap<String,DmRecord>();
      
      /**
       * 读取配置文件记录,生成对象
       */
      public void initialize(File file) throws IOException {
        BufferedReader in = null;
        try {
          in = new BufferedReader(new InputStreamReader(new FileInputStream(file)));
          String line;
          while ((line = in.readLine()) != null) {
            String [] strKey = line.split("\t");
            urlMap.put(strKey[0],parse(line));
          }
        } finally {
          IOUtils.closeStream(in);
        }
      }
      
      /**
       * 生成坐标对象
       */
      public DmRecord parse(String line){
        String [] strPlate = line.split("\t");
        DmRecord Dmurl = new DmRecord(strPlate[0],Integer.parseInt(strPlate[1]),Integer.parseInt(strPlate[2]));
        return Dmurl;
      }
      
      /**
       * 获取分类中心坐标
       */
      public DmRecord getUrlCode(String cluster){
          DmRecord returnCode = null;
        DmRecord dmUrl = (DmRecord)urlMap.get(cluster);
        if(dmUrl == null){
          //35     6
            returnCode = null; 
        }else{
            returnCode =dmUrl;
        }
        return returnCode;
      }
}


==============================================================================
package com.mahout.cluster;

import java.io.File;
import java.io.IOException;
import java.util.Iterator;

import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.TextInputFormat;
import org.apache.hadoop.mapred.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import com.mahout.test.StringStringPairAsce;


public class Kmeans  extends Configured implements Tool {

    public static class KmeansMapper extends MapReduceBase implements
            Mapper<LongWritable, Text, Text, Text> {
        private DmRecordParser drp ;
        private String clusterNode = "K";
        private DmRecord record0 = null;
        private DmRecord record1 = new DmRecord();
        private double Min_distance = 9999;
        private int tmpK = 0;
        private Text tKey = new Text();
        private Text tValue = new Text();
        
        //获取聚类中心坐标
        @Override
        public void configure(JobConf conf) {
            drp = new DmRecordParser();
            try {
                drp.initialize(new File("cluster.center.conf"));
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        
        //根据聚类坐标,把文件中的点进行类别划分
        @Override
        public void map(LongWritable key, Text value,
                OutputCollector<Text, Text> output, Reporter arg3)
                throws IOException {
            String [] strArr = value.toString().split("\t");
            
            for(int i=1; i <= 2; i++){
                record0 = drp.getUrlCode("K"+i);
                record1.setName(strArr[0]);
                record1.setXpoint(Double.parseDouble(strArr[1]));
                record1.setXpoint(Integer.parseInt(strArr[2]));
                
                if(record0.distance(record1) < Min_distance){
                    tmpK = i;
                    Min_distance = record0.distance(record1);
                }
            }            
            
            tKey.set("C"+tmpK);
            output.collect(tKey, value);
        }
    }
    
    //计算新的聚类中心
    public static class KmeansReducer extends MapReduceBase implements
            Reducer<Text, Text, Text, Text> {        
        private Text tKey = new Text();
        private Text tValue = new Text();
        
        @Override
        public void reduce(Text key, Iterator<Text> value,
                OutputCollector<Text, Text> output, Reporter arg3)
                throws IOException {
            double avgX=0;
            double avgY=0;
            double sumX=0;
            double sumY=0;
            int count=0;
            String [] strValue = null;
            
            while(value.hasNext()){
                count++;
                strValue = value.next().toString().split("\t");
                sumX = sumX + Integer.parseInt(strValue[1]);
                sumY = sumY + Integer.parseInt(strValue[1]);
            }
            
            avgX = sumX/count;
            avgY = sumY/count;
            tKey.set("K"+key.toString().substring(1,2));
            tValue.set(avgX + "\t" + avgY);
            System.out.println("K"+key.toString().substring(1,2)+"\t"+avgX + "\t" + avgY);
            output.collect(tKey, tValue);
        }
    }
    
    @Override
    public int run(String[] args) throws Exception {
        JobConf conf = new JobConf(getConf(), Kmeans.class);
        conf.setJobName("Kmeans");
        //conf.setNumMapTasks(200);

        // 设置Map输出的key和value的类型
        conf.setMapOutputKeyClass(Text.class);
        conf.setMapOutputValueClass(Text.class);

        // 设置Reduce输出的key和value的类型
        conf.setOutputKeyClass(Text.class);
        conf.setOutputValueClass(Text.class);

        // 设置Mapper和Reducer
        conf.setMapperClass(KmeansMapper.class);
        conf.setReducerClass(KmeansReducer.class);
        
        conf.setInputFormat(TextInputFormat.class);
        conf.setOutputFormat(TextOutputFormat.class);

        // 设置输入输出目录
        FileInputFormat.setInputPaths(conf, new Path(args[0]));
        FileOutputFormat.setOutputPath(conf, new Path(args[1]));

        JobClient.runJob(conf);
        return 0;
    }

    public static void main(String[] args) throws Exception {
        int exitCode = ToolRunner.run(new Kmeans(), args);
        System.exit(exitCode);
    }
}