实现ThreadFactory接口生成自定义的线程给Fork/Join框架

时间:2023-03-08 16:17:24

  Fork/Join框架是Java7中最有趣的特征之一。它是Executor和ExecutorService接口的一个实现,允许你执行Callable和Runnable任务而不用管理这些执行线程。这个执行者面向执行能被拆分成更小部分的任务。主要组件如下:

  • 一个特殊任务,实现ForkJoinTask类
  • 两种操作,将任务划分成子任务的fork操作和等待这些子任务结束的join操作
  • 一个算法,优化池中线程的使用的work-stealing算法。当一个任务正在等待它的子任务(结束)时,它的执行线程将执行其他任务(等待执行的任务)。

  ForkJoinPool类是Fork/Join的主要类。在它的内部实现,有如下两种元素:

  • 一个存储等待执行任务的列队。
  • 一个执行任务的线程池

  在这个指南中,你将学习如何实现一个在ForkJoinPool类中使用的自定义的工作者线程,及如何使用一个工厂来使用它。

   要自定义ForkJoinPool类使用的线程,必须继承ForkJoinWorkerThread

public class MyWorkerThread extends ForkJoinWorkerThread {
private static ThreadLocal<Integer> taskCounter = new ThreadLocal<Integer>();
protected MyWorkerThread(ForkJoinPool pool) {
super(pool);
}
@Override
protected void onStart() {
super.onStart();
System.out.println("MyWorkerThread " + getId()+ " : Initializing task counter");
taskCounter.set(0);
}
@Override
protected void onTermination(Throwable exception) {
System.out.println("MyWorkerThread " + getId() + " :"
+ taskCounter.get());
super.onTermination(exception);
}
public void addTask() {
int counter = taskCounter.get().intValue();
counter++;
taskCounter.set(counter);
}
}

  继承ForkJoinWorkerThreadFactory创建MyWorkerThreadFactory工厂

public class MyWorkerThreadFactory implements ForkJoinWorkerThreadFactory {
@Override
public MyWorkerThread newThread(ForkJoinPool pool) {
return new MyWorkerThread(pool);
}
}
public class MyRecursiveTask extends RecursiveTask<Integer> {
private int array[];
private int start, end;
public MyRecursiveTask(int[] array, int start, int end) {
super();
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
Integer ret;
MyWorkerThread thread = (MyWorkerThread) Thread.currentThread();
thread.addTask();
if (end - start > 100) {
int mid = (start + end) / 2;
MyRecursiveTask task1 = new MyRecursiveTask(array, start, mid);
MyRecursiveTask task2 = new MyRecursiveTask(array, mid, end);
invokeAll(task1, task2);
ret = addResults(task1, task2);
} else {
int add = 0;
for (int i = start; i < end; i++) {
add += array[i];
}
ret = new Integer(add);
}
try {
TimeUnit.MILLISECONDS.sleep(10);
} catch (InterruptedException e) {
e.printStackTrace();
}
return ret;
} private Integer addResults(MyRecursiveTask task1, MyRecursiveTask task2) {
int value = 0;
try {
value = task1.get().intValue() + task2.get().intValue();
} catch (Exception e) {
e.printStackTrace();
}
try {
TimeUnit.MILLISECONDS.sleep(10);
} catch (Exception e) {
e.printStackTrace();
}
return value;
}
}
public class ForkMain {
public static void main(String[] args) throws Exception {
MyWorkerThreadFactory factory=new MyWorkerThreadFactory();
ForkJoinPool joinPool=new ForkJoinPool(4, factory, null, false);
int array[]=new int[100000];
for (int i =0; i <100000; i++) {
array[i]=i;
}
MyRecursiveTask task=new MyRecursiveTask(array, 0, 10000);
joinPool.execute(task);
task.join();
joinPool.shutdown();
joinPool.awaitTermination(1, TimeUnit.DAYS);
System.out.println("Main: Result:"+task.get());
System.out.println("Main:Ends");
}
}