部分内容来自以下博客:
https://segmentfault.com/a/1190000016781127
https://segmentfault.com/a/1190000016877931
1 简介JDK1.7版本引入了一套Fork/Join框架。Fork/Join框架的基本思想就是将一个大任务分解(Fork)成一系列子任务,子任务可以继续往下分解,当多个不同的子任务都执行完成后,可以将它们各自的结果合并(Join)成一个大结果,最终合并成大任务的结果。
Fork/Join 框架要完成两件事情:
1)Fork:把一个复杂任务进行分拆
2)Join:把分拆任务的结果进行合并
Fork/Join框架的实现非常复杂,内部大量运用了位操作和无锁算法。
Fork/Join框架内部还涉及到三大核心组件:ForkJoinPool(线程池)、ForkJoinTask(任务)、ForkJoinWorkerThread(工作线程),外加WorkQueue(任务队列)。
2 类和接口 2.1 ForkJoinPoolForkJoinPool是分支合并池,类似于线程池ThreadPoolExecutor,同样是ExecutorService接口的一个实现类。
ForkJoinPool类的实现:
public class ForkJoinPool extends AbstractExecutorService {
在ForkJoinPool类中提供了三个构造方法:
public ForkJoinPool(); public ForkJoinPool(int parallelism); public ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode);
最终调用的是下面这个私有构造器:
private ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, int mode, String workerNamePrefix);
其参数含义如下:
parallelism:并行级别,默认值为CPU核心数,ForkJoinPool里工作线程数量与该参数有关,但它不表示最大线程数。
factory:工作线程工厂,默认是DefaultForkJoinWorkerThreadFactory,其实就是用来创建ForkJoinWorkerThread工作线程对象。
handler:异常处理器。
mode:调度模式,true表示FIFO_QUEUE,false表示LIFO_QUEUE。
workerNamePrefix:工作线程的名称前缀。
2.2 ForkJoinTaskForkJoinTask是Future接口的抽象实现类,提供了用于分解任务的fork()方法和用于合并任务的join()方法。
在ThreadPoolExecutor类中,使用线程池执行任务调用的execute()方法中要求传入Runnable接口的实例。但是在ForkJoinPool类中,除了可以传入Runnable接口的实例外,还可以传入ForkJoinTask抽象类的实例,并且传入Runnable接口的实例也会被适配为ForkJoinTask抽象类的实例。
2.3 RecursiveTask通常情况下使用ForkJoinTask抽象类的实例,并不需要直接继承ForkJoinTask类,只需要继承其子类:
1)RecursiveAction:用于没有返回结果的任务
2)RecursiveTask:用于有返回结果的任务
其中,最常用的还是RecursiveTask类。
2.4 ForkJoinWorkerThreadForkJoinWorkerThread类是Thread的子类,作为线程池中的工作线程执行任务,其内部维护了一个WorkerQueue类型的双向任务队列。
工作线程在执行任务时,优先处理自身任务队列中的任务(FIFO或者LIFO),当自身队列中的任务为空时,会窃取其他任务队列中的任务(FIFO)。
2.5 WorkerQueueWorkerQueue类是ForkJoinPool类的一个内部类,代表存储ForkJoinTask实例的双端队列。
在ForkJoinPool类的私有构造方法中,有一个int类型的mode参数,其取值如下:
static final int LIFO_QUEUE = 0; static final int FIFO_QUEUE = 1 << 16;
当入参为LIFO_QUEUE时,表示同步,对于工作线程(Worker)自身队列中的任务,采用后进先出(LIFO)的方式执行。
当入参为FIFO_QUEUE时,表示异步,对于工作线程(Worker)自身队列中的任务,采用先进先出(FIFO)的方式执行。
3 实现原理 3.1 提交任务使用ForkJoinPool的submit方法提交任务得到ForkJoinTask对象:
publicForkJoinTask submit(ForkJoinTask task) { if (task == null) throw new NullPointerException(); externalPush(task); return task; }
继续查看externalPush方法:
final void externalPush(ForkJoinTask> task) {
WorkQueue[] ws; WorkQueue q; int m;
int r = ThreadLocalRandom.getProbe();
int rs = runState;
if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
(q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
U.compareAndSwapInt(q, QLOCK, 0, 1)) {
ForkJoinTask>[] a; int am, n, s;
if ((a = q.array) != null &&
(am = a.length - 1) > (n = (s = q.top) - q.base)) {
int j = ((am & s) << ASHIFT) + Abase;
U.putOrderedObject(a, j, task);
U.putOrderedInt(q, QTOP, s + 1);
U.putIntVolatile(q, QLOCK, 0);
if (n <= 1)
signalWork(ws, q);
return;
}
U.compareAndSwapInt(q, QLOCK, 1, 0);
}
externalSubmit(task);
}
该方法包含两个部分:
1)尝试将任务添加到任务队列,添加后则创建或激活一个工作线程,在此过程中使用了CAS保证线程安全。
2)添加队列失败,则调用externalSubmit方法初始化队列,并将任务加入到队列。
3.2 分解任务 3.2.1 创建或唤醒工作线程调用ForkJoinTask的fork方法完成任务分解:
public final ForkJoinTaskfork() { Thread t; if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)// 调用线程为工作线程 ((ForkJoinWorkerThread)t).workQueue.push(this);// 将任务添加到自身队列 else ForkJoinPool.common.externalPush(this);// 调用ForkJoinPool的externalPush方法 return this; }
该方法包含两个部分:
1)调用线程为工作线程,将任务添加到自身队列。
2)调用线程为其他外部线程,继续调用ForkJoinPool的externalPush方法,尝试将任务添加到任务队列并激活工作线程。
继续查看push方法,添加任务到自身队列:
final void push(ForkJoinTask> task) {
ForkJoinTask>[] a; ForkJoinPool p;
int b = base, s = top, n;
if ((a = array) != null) { // ignore if queue removed
int m = a.length - 1; // fenced write for task visibility
U.putOrderedObject(a, ((m & s) << ASHIFT) + Abase, task);
U.putOrderedInt(this, QTOP, s + 1);
if ((n = s - b) <= 1) {
if ((p = pool) != null)
p.signalWork(p.workQueues, this);// 唤醒或创建工作线程
}
else if (n >= m)
growArray();// 扩容
}
}
该方法包含两个部分:
1)判断是否需要扩容,不需要扩容则唤醒或创建工作线程。
2)需要扩容,则进行扩容操作。
继续查看signalWork方法,创建或唤醒工作线程:
final void signalWork(WorkQueue[] ws, WorkQueue q) {
long c; int sp, i; WorkQueue v; Thread p;
while ((c = ctl) < 0L) { // too few active
if ((sp = (int)c) == 0) { // 没有空闲工作进程
if ((c & ADD_WORKER) != 0L) // 工作进程太少
tryAddWorker(c);// 增加工作进程
break;
}
// 有工作进程,唤醒
if (ws == null) // unstarted/terminated
break;
if (ws.length <= (i = sp & SMASK)) // terminated
break;
if ((v = ws[i]) == null) // terminating
break;
int vs = (sp + SS_SEQ) & ~INACTIVE; // next scanState
int d = sp - v.scanState; // screen CAS
long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
v.scanState = vs; // activate v
if ((p = v.parker) != null)
U.unpark(p);
break;
}
if (q != null && q.base == q.top) // no more work
break;
}
}
继续查看tryAddWorker方法:
private void tryAddWorker(long c) {
boolean add = false;
do {
// 设置活跃工作线程数和总工作线程数
long nc = ((AC_MASK & (c + AC_UNIT)) |
(TC_MASK & (c + TC_UNIT)));
if (ctl == c) {
int rs, stop; // check if terminating
if ((stop = (rs = lockRunState()) & STOP) == 0)
add = U.compareAndSwapLong(this, CTL, c, nc);
unlockRunState(rs, rs & ~RSLOCK);
if (stop != 0)
break;
if (add) {
// 创建工作线程
createWorker();
break;
}
}
} while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
}
继续查看createWorker方法:
private boolean createWorker() {
ForkJoinWorkerThreadFactory fac = factory;
Throwable ex = null;
ForkJoinWorkerThread wt = null;
try {
// 使用线程池工厂创建线程
if (fac != null && (wt = fac.newThread(this)) != null) {
// 启动线程
wt.start();
return true;
}
} catch (Throwable rex) {
ex = rex;
}
// 出现异常,注销该工作线程
deregisterWorker(wt, ex);
return false;
}
3.2.2 启动任务
ForkJoinWorkerThread在执行start方法后,会执行run方法:
public void run() {
if (workQueue.array == null) { // only run once
Throwable exception = null;
try {
onStart();
pool.runWorker(workQueue);
} catch (Throwable ex) {
exception = ex;
} finally {
try {
onTermination(exception);
} catch (Throwable ex) {
if (exception == null)
exception = ex;
} finally {
pool.deregisterWorker(this, exception);
}
}
}
}
在run方法内部调用了ForkJoinPool对象的runWorker方法:
final void runWorker(WorkQueue w) {
w.growArray(); // 初始化任务队列
int seed = w.hint; // initially holds randomization hint
int r = (seed == 0) ? 1 : seed; // avoid 0 for xorShift
for (ForkJoinTask> t;;) {
if ((t = scan(w, r)) != null)// 尝试获取任务
w.runTask(t);// 执行任务
else if (!awaitWork(w, r))// 获取失败,加入等待任务队列
break;// 等待失败,跳出方法并注销工作线程
r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
}
}
3.2.3 窃取任务
使用scan方法窃取任务:
private ForkJoinTask> scan(WorkQueue w, int r) {
WorkQueue[] ws; int m;
if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
int ss = w.scanState; // initially non-negative
for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
WorkQueue q; ForkJoinTask>[] a; ForkJoinTask> t;
int b, n; long c;
if ((q = ws[k]) != null) {// 定位任务队列
if ((n = (b = q.base) - q.top) < 0 &&
(a = q.array) != null) { // non-empty
long i = (((a.length - 1) & b) << ASHIFT) + Abase;
if ((t = ((ForkJoinTask>)
U.getObjectVolatile(a, i))) != null &&
q.base == b) {
if (ss >= 0) {
if (U.compareAndSwapObject(a, i, t, null)) {
q.base = b + 1;
if (n < -1) // signal others
signalWork(ws, q);// 创建获唤醒工作线程执行任务
return t;
}
}
else if (oldSum == 0 && // try to activate
w.scanState < 0)
tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);// 唤醒栈顶工作线程
}
if (ss < 0) // refresh
ss = w.scanState;
r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
origin = k = r & m; // move and rescan
oldSum = checkSum = 0;
continue;
}
checkSum += b;
}
// 已扫描全部工作线程,但并未找到任务
if ((k = (k + 1) & m) == origin) { // continue until stable
if ((ss >= 0 || (ss == (ss = w.scanState))) &&
oldSum == (oldSum = checkSum)) {
if (ss < 0 || w.qlock < 0) // already inactive
break;
int ns = ss | INACTIVE; // 尝试对当前工作线程灭活
long nc = ((SP_MASK & ns) |
(UC_MASK & ((c = ctl) - AC_UNIT)));
w.stackPred = (int)c; // hold prev stack top
U.putInt(w, QSCANSTATE, ns);
if (U.compareAndSwapLong(this, CTL, c, nc))
ss = ns;
else
w.scanState = ss; // back out
}
checkSum = 0;
}
}
}
return null;
}
3.2.4 执行任务
窃取到任务后,调用runTask方法执行任务:
final void runTask(ForkJoinTask> task) {
if (task != null) {
scanState &= ~SCANNING; // mark as busy
(currentSteal = task).doExec();// 执行任务
U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
execLocalTasks();// 执行本地任务
ForkJoinWorkerThread thread = owner;
if (++nsteals < 0) // collect on overflow
transferStealCount(pool);// 增加窃取任务数
scanState |= SCANNING;
if (thread != null)
thread.afterTopLevelExec();// 执行钩子函数
}
}
3.2.5 阻塞等待
如何未窃取到任务,会调用awaitWork方法等待获取任务:
private boolean awaitWork(WorkQueue w, int r) {
if (w == null || w.qlock < 0) // w is terminating
return false;
for (int pred = w.stackPred, spins = SPINS, ss;;) {
if ((ss = w.scanState) >= 0)
break;
else if (spins > 0) {
r ^= r << 6; r ^= r >>> 21; r ^= r << 7;
if (r >= 0 && --spins == 0) { // randomize spins
WorkQueue v; WorkQueue[] ws; int s, j; AtomicLong sc;
if (pred != 0 && (ws = workQueues) != null &&
(j = pred & SMASK) < ws.length &&
(v = ws[j]) != null && // see if pred parking
(v.parker == null || v.scanState >= 0))
spins = SPINS; // continue spinning
}
}
else if (w.qlock < 0) // recheck after spins
return false;
else if (!Thread.interrupted()) {
long c, prevctl, parkTime, deadline;
int ac = (int)((c = ctl) >> AC_SHIFT) + (config & SMASK);
if ((ac <= 0 && tryTerminate(false, false)) ||
(runState & STOP) != 0) // pool terminating
return false;
if (ac <= 0 && ss == (int)c) { // is last waiter
prevctl = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & pred);
int t = (short)(c >>> TC_SHIFT); // shrink excess spares
if (t > 2 && U.compareAndSwapLong(this, CTL, c, prevctl))
return false; // else use timed wait
parkTime = IDLE_TIMEOUT * ((t >= 0) ? 1 : 1 - t);
deadline = System.nanoTime() + parkTime - TIMEOUT_SLOP;
}
else
prevctl = parkTime = deadline = 0L;
Thread wt = Thread.currentThread();
U.putObject(wt, PARKBLOCKER, this); // emulate LockSupport
w.parker = wt;
if (w.scanState < 0 && ctl == c) // recheck before park
U.park(false, parkTime);
U.putOrderedObject(w, QPARKER, null);
U.putObject(wt, PARKBLOCKER, null);
if (w.scanState >= 0)
break;
if (parkTime != 0L && ctl == c &&
deadline - System.nanoTime() <= 0L &&
U.compareAndSwapLong(this, CTL, c, prevctl))
return false; // shrink pool
}
}
return true;
}
3.3 合并任务
使用ForkJoinTask的join方法可以获取任务的执行结果:
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
查看doJoin方法:
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
return (s = status) < 0 ? s :
((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
(w = (wt = (ForkJoinWorkerThread)t).workQueue).
tryUnpush(this) && (s = doExec()) < 0 ? s :
wt.pool.awaitJoin(w, this, 0L) :
externalAwaitDone();
}
4 使用
任务类定义,因为需要返回结果,所以继承RecursiveTask,并覆写compute方法。
任务的拆分通过ForkJoinTask的fork方法执行,join方法用于等待任务执行后返回。
class SumTask extends RecursiveTask{ private static final int THRESHOLD = 10;// 拆分阈值 private int begin;// 拆分开始值 private int end;// 拆分结束值 public SumTask(int begin, int end) { this.begin = begin; this.end = end; } @Override protected Integer compute() { Integer value = 0; if (end - begin <= THRESHOLD) {// 小于阈值,直接计算 for (int i = begin; i <= end; i++) { value += i; } } else {// 大于阈值,递归计算 int middle = (begin + end) / 2; SumTask beginTask = new SumTask(begin, middle); SumTask endTask = new SumTask(middle + 1, end); beginTask.fork(); endTask.fork(); value = beginTask.join() + endTask.join(); } return value; } } public class DemoTest { public static void main(String[] args) { SumTask sumTask = new SumTask(1, 100); ForkJoinPool pool = new ForkJoinPool(); try { ForkJoinTask task = pool.submit(sumTask); System.out.println(task.get()); } catch (Exception e) { e.printStackTrace(); } finally { pool.shutdown(); } } }
最终结果是5050。



