大矩阵乘法 ODPS MapReduce

时间:2023-01-09 08:07:04

假设现有矩阵A和矩阵B,矩阵C=A*B:
对于A的[i,j]元素:它会参与C的第i行所有元素的计算,即 $C[i,1], C[i,2], \dots, C[i,K]$(K为矩阵B的列数)。
对于B的[j,k]元素:它会参与C的第k列所有元素的计算,即 $C[1,k], C[2,k], \dots, C[N,k]$(N为矩阵A的行数)。
则:C[i,k]的值为:
$$C[i,k] = \sum_{j=1}^{M} A[i,j] \cdot B[j,k]$$
(M为矩阵A的列数,同时也是矩阵B的行数。)
其中C[1,1]与C[2,2]的计算互不影响,可使用分布式计算MapReduce进行分解,将所有计算需要的元素集中在同一个key上面。

MatrixMultiMapReduce.java

package fresh_comp_offline;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import com.aliyun.odps.data.Record;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.mapred.JobClient;
import com.aliyun.odps.mapred.MapperBase;
import com.aliyun.odps.mapred.ReducerBase;
import com.aliyun.odps.mapred.RunningJob;
import com.aliyun.odps.mapred.conf.JobConf;
import com.aliyun.odps.mapred.utils.InputUtils;
import com.aliyun.odps.mapred.utils.OutputUtils;
import com.aliyun.odps.mapred.utils.SchemaUtils;

/**
* Large-matrix multiplication (C = A * B) implemented as an ODPS MapReduce job.
*
* @author wwhhf
*
*/

public class MatrixMultiMapReduce {
// Number of rows of matrix A (and therefore of the result C)
public static int n = 0;
// Number of columns of matrix A, which is also the number of rows of matrix B
public static int m = 0;
// Number of columns of matrix B (and therefore of the result C)
public static int k = 0;

public static class Node {

    /** Row index of the element. */
    private Long i;
    /** Column index of the element. */
    private Long j;
    /** Value stored at position (i, j). */
    private Long val;

    /**
     * Creates a matrix element.
     *
     * @param i   row index
     * @param j   column index
     * @param val value at (i, j)
     */
    public Node(Long i, Long j, Long val) {
        this.i = i;
        this.j = j;
        this.val = val;
    }

    /** @return the row index */
    public Long getI() {
        return this.i;
    }

    /** @param i the new row index */
    public void setI(Long i) {
        this.i = i;
    }

    /** @return the column index */
    public Long getJ() {
        return this.j;
    }

    /** @param j the new column index */
    public void setJ(Long j) {
        this.j = j;
    }

    /** @return the element value */
    public Long getVal() {
        return this.val;
    }

    /** @param val the new element value */
    public void setVal(Long val) {
        this.val = val;
    }

    /** Renders the element as {@code Node [i=..., j=..., val=...]}. */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("Node [i=");
        text.append(i);
        text.append(", j=").append(j);
        text.append(", val=").append(val);
        text.append("]");
        return text.toString();
    }
}

public static class MyComparator implements Comparator<Node> {

    /**
     * Orders nodes by row index {@code i} first and by column index
     * {@code j} second, both ascending.
     *
     * <p>The indices are boxed {@code Long}s, so they are compared by value
     * with {@link Long#compare(long, long)}. Comparing them with {@code ==}
     * would compare object references, and casting a {@code long}
     * difference to {@code int} can truncate or overflow.
     *
     * @param o1 first node; its indices must not be null
     * @param o2 second node; its indices must not be null
     * @return a negative, zero or positive value as {@code o1} sorts before,
     *         equal to or after {@code o2}
     */
    @Override
    public int compare(Node o1, Node o2) {
        int byRow = Long.compare(o1.getI(), o2.getI());
        if (byRow != 0) {
            return byRow;
        }
        return Long.compare(o1.getJ(), o2.getJ());
    }

}

/**
 * Fans every input matrix element out to all result cells that need it.
 *
 * <p>An input record is read as (type, i, j, ...): column 0 is the matrix
 * tag, columns 1 and 2 are the row and column index. The map output key is
 * (row of C, column of C); the map output value is the unchanged input
 * record.
 */
public static class MatrixMultiMapper extends MapperBase {

    /** Job-configuration key looked up for the number of rows of A. */
    private static final String CONF_ROWS_OF_A = "matrix.multi.n";
    /** Job-configuration key looked up for the number of columns of B. */
    private static final String CONF_COLS_OF_B = "matrix.multi.k";

    private Record key = null;
    private long rowsOfA = 0;
    private long colsOfB = 0;

    /**
     * Creates the reusable output key and resolves the matrix dimensions.
     *
     * <p>The dimensions are read from the job configuration and fall back
     * to the static fields of the enclosing class when the configuration
     * does not define them. Static fields keep their default of 0 in any
     * JVM that did not run {@code main}, so the configuration is preferred.
     */
    @Override
    public void setup(TaskContext context) throws IOException {
        key = context.createMapOutputKeyRecord();
        JobConf conf = context.getJobConf();
        rowsOfA = conf.getInt(CONF_ROWS_OF_A, MatrixMultiMapReduce.n);
        colsOfB = conf.getInt(CONF_COLS_OF_B, MatrixMultiMapReduce.k);
    }

    /**
     * Emits one copy of the record per result cell it contributes to.
     *
     * <p>An element A[i,j] contributes to C[i,col] for every column of B.
     * Any other record is treated as an element B[i,j] and contributes to
     * C[row,j] for every row of A.
     */
    @Override
    public void map(long recordNum, Record record, TaskContext context)
            throws IOException {
        String type = record.getString(0);
        Long i = record.getBigint(1);
        Long j = record.getBigint(2);
        if ("a".equals(type)) {
            // Matrix A. The counter must not be named k: that would shadow
            // the dimension and make the loop condition always true.
            for (long col = 1; col <= colsOfB; col++) {
                key.set(0, i);
                // Key columns are bigint, so store a Long, not an Integer.
                key.set(1, Long.valueOf(col));
                context.write(key, record);
            }
        } else {
            // Matrix B.
            for (long row = 1; row <= rowsOfA; row++) {
                key.set(0, Long.valueOf(row));
                key.set(1, j);
                context.write(key, record);
            }
        }
    }

}

/**
 * Computes one cell of C from all elements of A and B routed to its key.
 *
 * <p>The key is (row of C, column of C). Each value is read as
 * (type, i, j, value): records tagged "a" are elements of A, all others are
 * treated as elements of B.
 */
public static class MatrixMultiReducer extends ReducerBase {

    private Record output = null;

    /** Creates the reusable output record. */
    @Override
    public void setup(TaskContext context) throws IOException {
        this.output = context.createOutputRecord();
    }

    /**
     * Writes ("c", row, column, sum) where sum is the dot product of the
     * received A elements and B elements over their shared inner index.
     *
     * <p>Elements are matched by inner index (column of the A element, row
     * of the B element), not by list position, so an index present on only
     * one side contributes nothing. Null values are skipped. If the same
     * index arrives twice on one side, the later value replaces the earlier.
     * Nothing is written when either side contributed no element at all.
     */
    @Override
    public void reduce(Record key, Iterator<Record> values,
            TaskContext context) throws IOException {
        Map<Long, Long> fromA = new HashMap<>();
        Map<Long, Long> fromB = new HashMap<>();
        // Copy the needed columns out immediately rather than keeping
        // references to the Record objects handed out by the iterator.
        while (values.hasNext()) {
            Record val = values.next();
            Long cell = val.getBigint(3);
            if (cell == null) {
                continue;
            }
            if ("a".equals(val.getString(0))) {
                // A[i, inner]: the inner index is the column.
                fromA.put(val.getBigint(2), cell);
            } else {
                // B[inner, k]: the inner index is the row.
                fromB.put(val.getBigint(1), cell);
            }
        }
        if (fromA.isEmpty() || fromB.isEmpty()) {
            return;
        }

        // Primitive accumulator avoids boxing on every addition.
        long res = 0L;
        for (Map.Entry<Long, Long> entry : fromA.entrySet()) {
            Long partner = fromB.get(entry.getKey());
            if (partner != null) {
                res += entry.getValue() * partner;
            }
        }

        output.set(0, "c");
        output.set(1, key.get(0));
        output.set(2, key.get(1));
        output.set(3, Long.valueOf(res));
        context.write(output);
    }
}

/**
 * Configures and runs the matrix-multiplication job.
 *
 * <p>Reads table {@code matrix} and writes table {@code matrix_out}. The
 * three dimensions are stored in the static fields of this class and also
 * in the job configuration under {@code matrix.multi.n},
 * {@code matrix.multi.m} and {@code matrix.multi.k}. Static fields are only
 * set in the JVM that runs this method; the configuration entries let task
 * code obtain the dimensions from the job itself.
 *
 * @param args exactly three integers: rows of A, columns of A (= rows of
 *             B), columns of B
 * @throws IllegalArgumentException if the number of arguments is not three
 * @throws NumberFormatException    if an argument is not an integer
 * @throws Exception                if the job cannot be submitted or fails
 */
public static void main(String[] args) throws Exception {
    if (args.length != 3) {
        throw new IllegalArgumentException(
                "Input Error: expected 3 arguments <n> <m> <k>, got "
                        + args.length);
    }
    n = Integer.parseInt(args[0]);
    m = Integer.parseInt(args[1]);
    k = Integer.parseInt(args[2]);

    JobConf job = new JobConf();

    // Carry the dimensions with the job, not only in static fields.
    job.setInt("matrix.multi.n", n);
    job.setInt("matrix.multi.m", m);
    job.setInt("matrix.multi.k", k);

    job.setMapOutputKeySchema(SchemaUtils.fromString("i:bigint,j:bigint"));
    job.setMapOutputValueSchema(SchemaUtils
            .fromString("type:string,i:bigint,j:bigint,value:bigint"));

    InputUtils.addTable(TableInfo.builder().tableName("matrix").build(),
            job);
    OutputUtils.addTable(TableInfo.builder().tableName("matrix_out")
            .build(), job);

    job.setMapperClass(MatrixMultiMapper.class);
    job.setReducerClass(MatrixMultiReducer.class);

    RunningJob rj = JobClient.runJob(job);
    rj.waitForCompletion();
}

}

input

a,1,1,1
a,1,2,2
a,1,3,3
a,1,4,4
a,2,1,5
a,2,2,6
a,2,3,7
a,2,4,8
a,3,1,9
a,3,2,10
a,3,3,11
a,3,4,12
a,4,1,13
a,4,2,14
a,4,3,15
a,4,4,16
a,5,1,17
a,5,2,18
a,5,3,19
a,5,4,20
b,1,1,1
b,1,2,2
b,1,3,3
b,1,4,4
b,1,5,5
b,2,1,6
b,2,2,7
b,2,3,8
b,2,4,9
b,2,5,10
b,3,1,11
b,3,2,12
b,3,3,13
b,3,4,14
b,3,5,15
b,4,1,16
b,4,2,17
b,4,3,18
b,4,4,19
b,4,5,20

output

c,1,1,110
c,1,2,120
c,1,3,130
c,1,4,140
c,1,5,150
c,2,1,246
c,2,2,272
c,2,3,298
c,2,4,324
c,2,5,350
c,3,1,382
c,3,2,424
c,3,3,466
c,3,4,508
c,3,5,550
c,4,1,518
c,4,2,576
c,4,3,634
c,4,4,692
c,4,5,750
c,5,1,654
c,5,2,728
c,5,3,802
c,5,4,876
c,5,5,950