MapReduce实现矩阵乘法

时间:2023-01-09 08:06:52

转自  http://blog.csdn.net/liuxinghao/article/details/39958957


简单回顾一下矩阵乘法:

MapReduce实现矩阵乘法

矩阵乘法要求左矩阵的列数与右矩阵的行数相等,m×n的矩阵A,与n×p的矩阵B相乘,结果为m×p的矩阵C。详细内容可以查看:矩阵乘法

为了方便描述,先进行假设:

  • 矩阵A的行数为m,列数为n,aij为矩阵A第i行j列的元素。
  • 矩阵B的行数为n,列数为p,bij为矩阵B第i行j列的元素。

分析

  因为分布式计算的特点,需要找到相互独立的计算过程,以便能够在不同的节点上进行计算而不会彼此影响。根据矩阵乘法的公式,C中各个元素的计算都是相互独立的,即各个cij在计算过程中彼此不影响。这样的话,在Map阶段可以把计算所需要的元素都集中到同一个key中,然后,在Reduce阶段就可以从中解析出各个元素来计算cij

  另外,以a11为例,它将会在c11、c12……c1p的计算中使用。也就是说,在Map阶段,当我们从HDFS取出一行记录时,如果该记录是A的元素,则需要存储成p个<key, value>对,并且这p个key互不相同;如果该记录是B的元素,则需要存储成m个<key, value>对,同样的,m个key也应互不相同;但同时,用于存放计算cij的ai1、ai2……ain和b1j、b2j……bnj的<key, value>对的key应该都是相同的,这样才能被传递到同一个Reduce中。

设计

  普遍有一个共识是:数据结构+算法=程序,所以在编写代码之前需要先理清数据存储结构和处理数据的算法。

算法

map阶段

  在map阶段,需要做的是进行数据准备。把来自矩阵A的元素aij,标识成p条<key, value>的形式,key="i,k",(其中k=1,2,...,p),value="a:j,aij";把来自矩阵B的元素bij,标识成m条<key, value>形式,key="k,j"(其中k=1,2,...,m),value="b:i,bij"。

  经过处理,用于计算cij需要的a、b就转变为有相同key("i,j")的数据对,通过value中"a:"、"b:"能区分元素是来自矩阵A还是矩阵B,以及具体的位置(在矩阵A的第几列,在矩阵B的第几行)。

shuffle阶段

  这个阶段是Hadoop自动完成的阶段,具有相同key的value被分到同一个Iterable中,形成<key,Iterable(value)>对,再传递给reduce。

reduce阶段

  通过map数据预处理和shuffle数据分组两个阶段,reduce阶段只需要知道两件事就行:

  • <key,Iterable(value)>对经过计算得到的是矩阵C的哪个元素?因为map阶段对数据的处理,key(i,j)中的数据对,就是其在矩阵C中的位置,第i行j列。
  • Iterable中的每个value来自于矩阵A和矩阵B的哪个位置?这个也在map阶段进行了标记,对于value(x:y,z),只需要找到y相同的来自不同矩阵(即x分别为a和b)的两个元素,取z相乘,然后加和即可。

数据结构

  计算过程已经设计清楚了,就需要对数据结构进行设计。大体有两种设计方案:

  第一种:使用最原始的表示方式,相同行内不同列数据通过","分割,不同行通过换行分割;

  第二种:通过行列表示法,即文件中的每行数据有三个元素通过分隔符分割,第一个元素表示行,第二个元素表示列,第三个元素表示数据。这种方式对于可以不列出为0的元素,即可以减少稀疏矩阵的数据量。

  MapReduce实现矩阵乘法

  在上图中,第一种方式存储的数据量小于第二种,但这只是因为例子中的数据设计成这样。在现实中,使用分布式计算矩阵乘法的环境中,大部分矩阵是稀疏矩阵,且数据量极大,在这种情况下,第二种数据结构的优势就显现了出来。而且,因为使用分布式计算,如果数据大于64m,在map阶段将不能够逐行处理,将不能确定数据来自于哪一行。不过,由于现实中对于大矩阵的乘法,考虑到存储空间和内存的情况,需要特殊的处理方式,有一种是将矩阵进行行列转换然后计算,这个时候第一种还是挺实用的。

编写代码

第一种数据结构

代码为:

[java]  view plain  copy
  1. import java.io.IOException;  
  2. import java.util.HashMap;  
  3. import java.util.Iterator;  
  4. import java.util.Map;  
  5.   
  6. import org.apache.hadoop.conf.Configuration;  
  7. import org.apache.hadoop.fs.Path;  
  8. import org.apache.hadoop.io.IntWritable;  
  9. import org.apache.hadoop.io.LongWritable;  
  10. import org.apache.hadoop.io.Text;  
  11. import org.apache.hadoop.mapreduce.Job;  
  12. import org.apache.hadoop.mapreduce.Mapper;  
  13. import org.apache.hadoop.mapreduce.Reducer;  
  14. import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;  
  15. import org.apache.hadoop.mapreduce.lib.input.FileSplit;  
  16. import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;  
  17. import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;  
  18. import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;  
  19.   
  20. /** 
  21.  * @author liuxinghao 
  22.  * @version 1.0 Created on 2014年10月9日 
  23.  */  
  24. public class MatrixMultiply {  
  25.     public static class MatrixMapper extends  
  26.             Mapper<LongWritable, Text, Text, Text> {  
  27.         private String flag = null;// 数据集名称  
  28.         private int rowNum = 4;// 矩阵A的行数  
  29.         private int colNum = 2;// 矩阵B的列数  
  30.         private int rowIndexA = 1// 矩阵A,当前在第几行  
  31.         private int rowIndexB = 1// 矩阵B,当前在第几行  
  32.   
  33.         @Override  
  34.         protected void setup(Context context) throws IOException,  
  35.                 InterruptedException {  
  36.             flag = ((FileSplit) context.getInputSplit()).getPath().getName();// 获取文件名称  
  37.         }  
  38.   
  39.         @Override  
  40.         protected void map(LongWritable key, Text value, Context context)  
  41.                 throws IOException, InterruptedException {  
  42.             String[] tokens = value.toString().split(",");  
  43.             if ("ma".equals(flag)) {  
  44.                 for (int i = 1; i <= colNum; i++) {  //循坏多列
  45.                     Text k = new Text(rowIndexA + "," + i);  
  46.                     for (int j = 0; j < tokens.length; j++) {  
  47.                         Text v = new Text("a," + (j + 1) + "," + tokens[j]);  
  48.                         context.write(k, v);  
  49.                     }  
  50.                 }  
  51.                 rowIndexA++;// 每执行一次map方法,矩阵向下移动一行  
  52.             } else if ("mb".equals(flag)) {  
  53.                 for (int i = 1; i <= rowNum; i++) {  //循环多行
  54.                     for (int j = 0; j < tokens.length; j++) {  
  55.                         Text k = new Text(i + "," + (j + 1));  
  56.                         Text v = new Text("b," + rowIndexB + "," + tokens[j]);  
  57.                         context.write(k, v);  
  58.                     }  
  59.                 }  
  60.                 rowIndexB++;// 每执行一次map方法,矩阵向下移动一行  
  61.             }  
  62.         }  
  63.     }  
  64.   
  65.     public static class MatrixReducer extends  
  66.             Reducer<Text, Text, Text, IntWritable> {  
  67.         @Override  
  68.         protected void reduce(Text key, Iterable<Text> values, Context context)  
  69.                 throws IOException, InterruptedException {  
  70.             Map<String, String> mapA = new HashMap<String, String>();  
  71.             Map<String, String> mapB = new HashMap<String, String>();  
  72.   
  73.             for (Text value : values) {  
  74.                 String[] val = value.toString().split(",");  
  75.                 if ("a".equals(val[0])) {  
  76.                     mapA.put(val[1], val[2]);  
  77.                 } else if ("b".equals(val[0])) {  
  78.                     mapB.put(val[1], val[2]);  
  79.                 }  
  80.             }  
  81.   
  82.             int result = 0;  
  83.             Iterator<String> mKeys = mapA.keySet().iterator();  
  84.             while (mKeys.hasNext()) {  
  85.                 String mkey = mKeys.next();  
  86.                 if (mapB.get(mkey) == null) {// 因为mkey取的是mapA的key集合,所以只需要判断mapB是否存在即可。  
  87.                     continue;  
  88.                 }  
  89.                 result += Integer.parseInt(mapA.get(mkey))  
  90.                         * Integer.parseInt(mapB.get(mkey));  
  91.             }  
  92.             context.write(key, new IntWritable(result));  
  93.         }  
  94.     }  
  95.   
  96.     public static void main(String[] args) throws IOException,  
  97.             ClassNotFoundException, InterruptedException {  
  98.         String input1 = "hdfs://192.168.1.128:9000/user/lxh/matrix/ma";  
  99.         String input2 = "hdfs://192.168.1.128:9000/user/lxh/matrix/mb";  
  100.         String output = "hdfs://192.168.1.128:9000/user/lxh/matrix/out";  
  101.   
  102.         Configuration conf = new Configuration();  
  103.         conf.addResource("classpath:/hadoop/core-site.xml");  
  104.         conf.addResource("classpath:/hadoop/hdfs-site.xml");  
  105.         conf.addResource("classpath:/hadoop/mapred-site.xml");  
  106.         conf.addResource("classpath:/hadoop/yarn-site.xml");  
  107.   
  108.         Job job = Job.getInstance(conf, "MatrixMultiply");  
  109.         job.setJarByClass(MatrixMultiply.class);  
  110.         job.setOutputKeyClass(Text.class);  
  111.         job.setOutputValueClass(Text.class);  
  112.   
  113.         job.setMapperClass(MatrixMapper.class);  
  114.         job.setReducerClass(MatrixReducer.class);  
  115.   
  116.         job.setInputFormatClass(TextInputFormat.class);  
  117.         job.setOutputFormatClass(TextOutputFormat.class);  
  118.   
  119.         FileInputFormat.setInputPaths(job, new Path(input1), new Path(input2));// 加载2个输入数据集  
  120.         Path outputPath = new Path(output);  
  121.         outputPath.getFileSystem(conf).delete(outputPath, true);  
  122.         FileOutputFormat.setOutputPath(job, outputPath);  
  123.   
  124.         System.exit(job.waitForCompletion(true) ? 0 : 1);  
  125.     }  
  126. }  

绘图演示效果:

MapReduce实现矩阵乘法

第二种数据结构

代码为:

[java]  view plain  copy
  1. import java.io.IOException;  
  2. import java.util.HashMap;  
  3. import java.util.Iterator;  
  4. import java.util.Map;  
  5.   
  6. import org.apache.hadoop.conf.Configuration;  
  7. import org.apache.hadoop.fs.Path;  
  8. import org.apache.hadoop.io.IntWritable;  
  9. import org.apache.hadoop.io.LongWritable;  
  10. import org.apache.hadoop.io.Text;  
  11. import org.apache.hadoop.mapreduce.Job;  
  12. import org.apache.hadoop.mapreduce.Mapper;  
  13. import org.apache.hadoop.mapreduce.Reducer;  
  14. import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;  
  15. import org.apache.hadoop.mapreduce.lib.input.FileSplit;  
  16. import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;  
  17. import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;  
  18. import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;  
  19.   
  20. /** 
  21.  * @author liuxinghao 
  22.  * @version 1.0 Created on 2014年10月10日 
  23.  */  
  24. public class SparseMatrixMultiply {  
  25.     public static class SMMapper extends Mapper<LongWritable, Text, Text, Text> {  
  26.         private String flag = null;  
  27.         private int m = 4;// 矩阵A的行数  
  28.         private int p = 2;// 矩阵B的列数  
  29.   
  30.         @Override  
  31.         protected void setup(Context context) throws IOException,  
  32.                 InterruptedException {  
  33.             FileSplit split = (FileSplit) context.getInputSplit();  
  34.             flag = split.getPath().getName();  
  35.         }  
  36.   
  37.         @Override  
  38.         protected void map(LongWritable key, Text value, Context context)  
  39.                 throws IOException, InterruptedException {  
  40.             String[] val = value.toString().split(",");  
  41.             if ("t1".equals(flag)) {  
  42.                 for (int i = 1; i <= p; i++) {  
  43.                     context.write(new Text(val[0] + "," + i), new Text("a,"  
  44.                             + val[1] + "," + val[2]));  
  45.                 }  
  46.             } else if ("t2".equals(flag)) {  
  47.                 for (int i = 1; i <= m; i++) {  
  48.                     context.write(new Text(i + "," + val[1]), new Text("b,"  
  49.                             + val[0] + "," + val[2]));  
  50.                 }  
  51.             }  
  52.         }  
  53.     }  
  54.   
  55.     public static class SMReducer extends  
  56.             Reducer<Text, Text, Text, IntWritable> {  
  57.         @Override  
  58.         protected void reduce(Text key, Iterable<Text> values, Context context)  
  59.                 throws IOException, InterruptedException {  
  60.             Map<String, String> mapA = new HashMap<String, String>();  
  61.             Map<String, String> mapB = new HashMap<String, String>();  
  62.   
  63.             for (Text value : values) {  
  64.                 String[] val = value.toString().split(",");  
  65.                 if ("a".equals(val[0])) {  
  66.                     mapA.put(val[1], val[2]);  
  67.                 } else if ("b".equals(val[0])) {  
  68.                     mapB.put(val[1], val[2]);  
  69.                 }  
  70.             }  
  71.   
  72.             int result = 0;  
  73.             // 可能在mapA中存在在mapB中不存在的key,或相反情况  
  74.             // 因为,数据定义的时候使用的是稀疏矩阵的定义  
  75.             // 所以,这种只存在于一个map中的key,说明其对应元素为0,不影响结果  
  76.             Iterator<String> mKeys = mapA.keySet().iterator();  
  77.             while (mKeys.hasNext()) {  
  78.                 String mkey = mKeys.next();  
  79.                 if (mapB.get(mkey) == null) {// 因为mkey取的是mapA的key集合,所以只需要判断mapB是否存在即可。  
  80.                     continue;  
  81.                 }  
  82.                 result += Integer.parseInt(mapA.get(mkey))  
  83.                         * Integer.parseInt(mapB.get(mkey));  
  84.             }  
  85.             context.write(key, new IntWritable(result));  
  86.         }  
  87.     }  
  88.   
  89.     public static void main(String[] args) throws IOException,  
  90.             ClassNotFoundException, InterruptedException {  
  91.         String input1 = "hdfs://192.168.1.128:9000/user/lxh/matrix/t1";  
  92.         String input2 = "hdfs://192.168.1.128:9000/user/lxh/matrix/t2";  
  93.         String output = "hdfs://192.168.1.128:9000/user/lxh/matrix/out";  
  94.   
  95.         Configuration conf = new Configuration();  
  96.         conf.addResource("classpath:/hadoop/core-site.xml");  
  97.         conf.addResource("classpath:/hadoop/hdfs-site.xml");  
  98.         conf.addResource("classpath:/hadoop/mapred-site.xml");  
  99.         conf.addResource("classpath:/hadoop/yarn-site.xml");  
  100.   
  101.         Job job = Job.getInstance(conf, "SparseMatrixMultiply");  
  102.         job.setJarByClass(SparseMatrixMultiply.class);  
  103.         job.setOutputKeyClass(Text.class);  
  104.         job.setOutputValueClass(Text.class);  
  105.   
  106.         job.setMapperClass(SMMapper.class);  
  107.         job.setReducerClass(SMReducer.class);  
  108.   
  109.         job.setInputFormatClass(TextInputFormat.class);  
  110.         job.setOutputFormatClass(TextOutputFormat.class);  
  111.   
  112.         FileInputFormat.setInputPaths(job, new Path(input1), new Path(input2));// 加载2个输入数据集  
  113.         Path outputPath = new Path(output);  
  114.         outputPath.getFileSystem(conf).delete(outputPath, true);  
  115.         FileOutputFormat.setOutputPath(job, outputPath);  
  116.   
  117.         System.exit(job.waitForCompletion(true) ? 0 : 1);  
  118.     }  
  119. }  

绘图演示效果:

MapReduce实现矩阵乘法

代码分析

  比较两种代码,可以很清楚的看出,两种实现只是在map阶段有些区别,reduce阶段基本相同。对于其中关于行i、列j定义不是从0计数(虽然我倾向于从0开始计数,不用写等号,简单),是为了更直观的观察数据处理过程是否符合设计。

  在第一种实现中,需要记录当前是读取的哪一行数据,所以,这种仅适用于不需要分块的小文件中进行的矩阵乘法运算。第二种实现中,每行数据记录了所在行所在列,不会有这方面的限制。

  在第二种实现中,遍历两个HashMap时,取mapA的key作为循环标准,是因为在一般情况下,mapA和mapB的key是相同的(如第一种实现),因为使用稀疏矩阵,两个不相同的key说明是0,可以舍弃不参与计算,所以只使用mapA的key,并判断mapB是否存在该key对应的值。

  两种实现的reduce阶段,计算最后结果时,都是直接使用内存存储数据、计算结果,所以当数据量很大的时候(通常都会很大,否则不会用分布式处理),极易造成内存溢出,所以,对于大矩阵的运算,还需要其他的转换方式,比如行列相乘运算、分块矩阵运算、基于最小粒度相乘的算法等方式。另外,因为这两份代码都是demo,所以代码中缺少过滤错误数据的部分。