乱弹java并发(九)-- fork/join

时间:2022-09-21 10:14:27

ForkJoin是在java7引入的由Doug Lea设计的一个并行计算框架,ForkJoin框架包含两部分:1、把任务Fork成一系列的相互之间无依赖的递归子任务,不同的子任务可以由不同的CPU核心执行;2、合并(Join)子任务的执行结果。Doug Lea在ForkJoin中引入一个叫work steal的算法,即当某线程的任务队列已执行完时,可以扫描其它线程的任务队列,从其它线程的任务队列队尾取出任务执行,由于线程取自己的任务队列时从头部开始取的,所以当队列大小大于1时,work steal时不会造成结点竞争的。通过work steal的方式来均衡各线程的计算量,只有当参与steal的线程由不同的CPU(或CPU核心)执行时,work steal才能提高CPU的吞吐量,否则这种均衡线程计算量的做法毫无意义,下面是一个ForkJoin的例子,用来实现归并排序,注意只是个例子,并不是最佳实现,事实上下面用ForkJoin框架实现的归并比正常串行执行的归并排序还要慢。

public class ParallelSort {

private static final int SORT_THREHOLD = 7;

public static void main(String[] args) {
int[] array = initArray(1000);
long startTime = System.currentTimeMillis();
sort(array);
System.out.println(System.currentTimeMillis() - startTime);
printArray(array);
}

public static void sort(int[] array) {
ForkJoinPool pool = new ForkJoinPool();
int[] copy = Arrays.copyOf(array, array.length);
SortTask task = new SortTask(copy, array, 0, array.length);
pool.submit(task);
try {
task.get();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
}

private static int[] initArray(int count) {
int[] array = new int[count];
for (int i = count - 1; i >= 0; i--) {
array[count - i - 1] = i;
}
return array;
}

private static void printArray(int[] array) {
for (int a : array) {
System.out.println(a);
}
}

static class SortTask extends RecursiveAction {

/**
*
*/
private static final long serialVersionUID = 2761626835686763286L;
private int[] src;
private int[] dest;
private int low;
private int high;

public SortTask(int[] src, int[] dest, int low, int high) {
super();
this.src = src;
this.dest = dest;
this.low = low;
this.high = high;
}

@Override
protected void compute() {
// System.out.println("thread name:"
// + Thread.currentThread().getName());
mergeSort();
}

private void mergeSort() {

int length = high - low;
if (length <= SORT_THREHOLD) {
for (int i = low; i < high - 1; i++) {
for (int j = i; j < high; j++) {
if (dest[i] > dest[j]) {
swap(dest, i, j);
}
}
}
return;
}
int mid = (low + high) >>> 1;
// mergeSort(dest, src, low, mid);
// mergeSort(dest, src, mid, high);
SortTask lowTask = new SortTask(dest, src, low, mid);
SortTask highTask = new SortTask(dest, src, mid, high);
lowTask.fork();
highTask.fork();
lowTask.join();
highTask.join();
if (src[mid - 1] < src[mid]) {
for (int i = low; i < high; i++) {
dest[i] = src[i];
}
return;
}
int i = low;
int p = low;
int q = mid;
while (p < mid) {
if (q >= high || src[p] < src[q]) {
dest[i++] = src[p++];
} else {
dest[i++] = src[q++];
}
}

}

private static void swap(int[] array, int i, int j) {
int temp = array[i];
array[i] = array[j];
array[j] = temp;
}

}
}