java.util.concurrent 包源码分析之线程同步辅助

时间:2021-07-18 19:28:04

CyclicBarrier

CyclicBarrier是一个用于线程同步的辅助类,它允许一组线程等待彼此,直到所有线程都到达集合点,然后执行某个设定的任务。现实中有个很好的例子来形容:几个人约定了某个地方集中,然后一起出发去旅行。每个参与的人就是一个线程,CyclicBarrier就是那个集合点,所有人到了之后,就一起出发。

CyclicBarrier的构造函数有两个:

// parties是参与等待的线程的数量,barrierAction是所有线程达到集合点之后要做的动作
public CyclicBarrier(int parties, Runnable barrierAction);

// 达到集合点之后不执行操作的构造函数
public CyclicBarrier(int parties)

需要说明的是,CyclicBarrier只是记录线程的数目,CyclicBarrier是不创建任何线程的。线程是通过调用CyclicBarrier的await方法来等待其他线程,如果调用await方法的线程数目达到了预设值,也就是上面构造方法中的parties,CyclicBarrier就会开始执行barrierAction。

因此我们来看CyclicBarrier的核心方法dowait,也就是await方法调用的私有方法:

   private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            final Generation g = generation;

            if (g.broken)
                throw new BrokenBarrierException();

            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }
           // count就是预设的parties,count减1的值表示还剩余几个
           // 线程没有达到该集合点
           int index = --count;
           // index为0表示所有的线程都已经达到集合点,这时
           // 占用最后一个线程,执行运行设定的任务
           if (index == 0) {
               boolean ranAction = false;
               try {
                   final Runnable command = barrierCommand;
                   if (command != null)
                       command.run();
                   ranAction = true;
                   // 唤醒其他等待的线程,
                   // 更新generation以便下一次运行
                   nextGeneration();
                   return 0;
               } finally {
                   // 如果运行任务时发生异常,设置状态为broken
                   // 并且唤醒其他等待的线程
                   if (!ranAction)
                       breakBarrier();
               }
           }

            // 还有线程没有调用await,进入循环等待直到其他线程
            // 达到集合点或者等待超时
            for (;;) {
                try {
                    // 如果没有设置超时,进行无超时的等待
                    if (!timed)
                        trip.await();
                    // 有超时设置,进行有超时的等待
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    // generation如果没有被更新表示还是当前的运行
                    // (generation被更新表示集合完毕并且任务成功),
                    // 在状态没有被设置为broken状态的情况下,遇到线程
                    // 中断异常表示当前线程等待失败,需要设置为broken
                    // 状态,并且抛出中断异常
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // else对应的条件为:g != generation || g.broken
                        // 表示要么generation已经被更新意味着所有线程已经到达
                        // 集合点并且任务执行成功,要么就是是broken状态意味着
                        // 任务执行失败,无论哪种情况所有线程已经达到集合点,当
                        // 前线程要结束等待了,发生了中断异常,需要中断当前线程
                        // 表示遇到了中断异常。
                        Thread.currentThread().interrupt();
                    }
                }

                // 如果发现当前状态为broken,抛出异常
                if (g.broken)
                    throw new BrokenBarrierException();
                // generation被更新表示所有线程都已经达到集合点
                // 并且预设任务已经完成,返回该线程进入等待顺序号
                if (g != generation)
                    return index;
                // 等待超时,设置为broken状态并且抛出超时异常
                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();
        }
    }
  1. 任何一个线程等待时发生异常,CyclicBarrier都将被设置为broken状态,运行都会失败

  2. 每次运行成功之后CyclicBarrier都会清理运行状态,这样CyclicBarrier可以重新使用

  3. 对于设置了超时的等待,在发生超时的时候会引起CyclicBarrier的broken

CountDownLatch

CountDownLatch同样也是一个线程同步的辅助类,同样适用上面的集合点的场景来解释,但是运行模式完全不同。CyclicBarrier是参与的所有的线程彼此等待,CountDownLatch则不同,CountDownLatch有一个导游线程在等待,每个线程报到一下即可无须等待,等到导游线程发现所有人都已经报到了,就结束了自己的等待。

CountDownLatch的构造方法允许指定参与的线程数量:

public CountDownLatch(int count)

参与线程使用countDown表示报到:

   public void countDown() {
        sync.releaseShared(1);
    }

看到releaseShared很容易使人联想到共享锁,那么试着用共享锁的运行模式来解释就简单得多了:和信号量的实现类似,CountDownLatch内置一下有限的共享锁。每个参与线程拥有一把共享锁,调用countDown就等于是释放了自己的共享锁,导游线程await等于一下子要拿回所有的共享锁。那么基于AbstractQueuedSynchronizer类来实现就很简单了:

   public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

在await时注意到数量是1,其实这个参数对于CountDownLatch实现的Sync类(AbstractQueuedSynchronizer的子类)来说是不起作用的,因为需要保证await获取共享锁时必须拿到所有的共享锁,这个参数也就变得没有意义了。看一下Sync的tryAcquireShared方法就明白了:

        protected int tryAcquireShared(int acquires) {
            // 和信号量Semaphore的实现一样,使用state来存储count,
            // 每次释放共享锁就把state减1,state为0表示所有的共享
            // 锁已经被释放。注意:这里的acquires参数不起作用
            return (getState() == 0) ? 1 : -1;
        }

因此Sync的tryReleaseShared就是更新state(每次state减1):

 protected boolean tryReleaseShared(int releases) {
            // 每次state减1,当state为0,返回false表示所有的共享锁都已经释放
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }

CyclicBarrier和CountDownLatch本质上来说都是多个线程同步的辅助工具,前者可以看成分布式的,后者可以看出是主从式。

Phaser

Phaser是JDK7新添加的线程同步辅助类,作用同CyclicBarrier,CountDownLatch类似,但是使用起来更加灵活:

  1. Parties是动态的。

  2. Phaser支持树状结构,即Phaser可以有一个父Phaser。

Phaser的构造函数涉及到两个参数:父Phaser和初始的parties,因此提供了4个构造函数:

public Phaser();
public Phaser(int parties);
public Phaser(Phaser parent);
public Phaser(Phaser parent, int parties);

因为Phaser的特色在在于动态的parties,因此首先来看动态更新parties是如何实现的。

Phaser提供了两个方法:register和bulkRegister,前者会添加一个需要同步的线程,后者会添加parties个需要同步的线程。

    public int register() {
        return doRegister(1);
    }

    // 增加了参数的检查
    public int bulkRegister(int parties) {
        if (parties < 0)
            throw new IllegalArgumentException();
        if (parties == 0)
            return getPhase();
        return doRegister(parties);
    }

两个方法都调用了doRegister方法,因此接下来就来看看doRegister方法。

在分析doRegister之前先来说说Phaser的成员变量:state,它存储了Phaser的状态信息:

private volatile long state;
  1. state的最高位是一个标志位,1表示Phaser的线程同步已经结束,0表示线程同步正在进行

  2. state的低32位中,低16位表示没有到达的线程数量,高16位表示Parties值

  3. state的高32位除了最高位之外的其他31位表示的Phaser的phase,可以理解为第多少次同步(从0开始计算)。

介绍完了state,来看方法doRegister:

 private int doRegister(int registrations) {
        // 把registrations值同时加到parties值和还未达到的线程数量中去
        long adj = ((long)registrations << PARTIES_SHIFT) | registrations;
        final Phaser parent = this.parent;
        int phase;
        for (;;) {
            long s = state;
            int counts = (int)s;
            int parties = counts >>> PARTIES_SHIFT;
            int unarrived = counts & UNARRIVED_MASK;
            // 超过了允许的最大parties
            if (registrations > MAX_PARTIES - parties)
                throw new IllegalStateException(badRegister(s));
            // 最高位为1,表示Phaser的线程同步已经结束
            else if ((phase = (int)(s >>> PHASE_SHIFT)) < 0)
                break;
            // Phaser中的parties不是0
            else if (counts != EMPTY) {
                // 如果当前Phaser没有父Phaser,或者如果有父Phaser,
                // 刷新自己的state值,如果刷新后的state没有变化。
                // 这里刷新子Phaser的原因在于,会出现父Phaser已经进入下一个phase
                // 而子Phaser却没有及时进入下一个phase的延迟现象
                if (parent == null || reconcileState() == s) {
                    // 如果所有线程都到达了,等待Phaser进入下一次同步开始
                    if (unarrived == 0)
                        root.internalAwaitAdvance(phase, null);
                    // 更新state成功,跳出循环完成注册
                    else if (UNSAFE.compareAndSwapLong(this, stateOffset,
                                                       s, s + adj))
                        break;
                }
            }
            // 第一次注册,且不是子Phaser
            else if (parent == null) {
                // 更新当前Phaser的state值成功则完成注册
                long next = ((long)phase << PHASE_SHIFT) | adj;
                if (UNSAFE.compareAndSwapLong(this, stateOffset, s, next))
                    break;
            }
            // 第一次注册到子Phaser
            else {
                // 锁定当前Phaser对象
                synchronized (this) {
                    // 再次检查state值,确保没有被更新
                    if (state == s) {
                        // 注册到父Phaser中去
                        parent.doRegister(1);
                        do { // 获取当前phase值
                            phase = (int)(root.state >>> PHASE_SHIFT);
                        } while (!UNSAFE.compareAndSwapLong
                                 (this, stateOffset, state,
                                  ((long)phase << PHASE_SHIFT) | adj));// 更新当前Phaser的state值
                        break;
                    }
                }
            }
        }
        return phase;
    }

看完了注册,那么来看同步操作的arrive,这里也涉及到两个方法:arrive和arriveAndDeregister,前者会等待其他线程的到达,后者则会立刻返回:

  public int arrive() {
        return doArrive(false);
    }

    public int arriveAndDeregister() {
        return doArrive(true);
    }

两个方法都调用了doArrive方法,区别在于参数一个是false,一个是true。那么来看doArrive:

  private int doArrive(boolean deregister) {
        // arrive需要把未到达的线程数减去1,
        // deregister为true,需要把parties值也减去1
        int adj = deregister ? ONE_ARRIVAL|ONE_PARTY : ONE_ARRIVAL;
        final Phaser root = this.root;
        for (;;) {
            // 如果是有父Phaser,首先刷新自己的state
            long s = (root == this) ? state : reconcileState();
            int phase = (int)(s >>> PHASE_SHIFT);
            int counts = (int)s;
            int unarrived = (counts & UNARRIVED_MASK) - 1;
            // 最高位为1,表示同步已经结束,返回phase值
            if (phase < 0)
                return phase;
            // 如果parties为0或者在此次arrive之前所有线程到达
            else if (counts == EMPTY || unarrived < 0) {
                // 对于非子Phaser来说,上述情况的arrive肯定是非法的
                // 对于子Phaser首先刷新一下状态再做检查
                if (root == this || reconcileState() == s)
                    throw new IllegalStateException(badArrive(s));
            }
            // 正常情况下,首先更新state
            else if (UNSAFE.compareAndSwapLong(this, stateOffset, s, s-=adj)) {
                // 所有线程都已经到达
                if (unarrived == 0) {
                    // 计算parties作为下一个phase的未到达的parties
                    long n = s & PARTIES_MASK;
                    int nextUnarrived = (int)n >>> PARTIES_SHIFT;
                    // 调用父Phaser的doArrive
                    if (root != this)
                        // 如果下一个phase的未到达的parties为0,则需要向
                        // 父Phaser取消注册
                        return parent.doArrive(nextUnarrived == 0);
                    // 正在进入下一个Phase,默认的实现是nextUnarrived为0
                    // 表示正在进入下一个Phase,因为下一个phase的parties
                    // 为0,需要等待parties不为0
                    if (onAdvance(phase, nextUnarrived))
                        // 正在等待下一个phase,设置状态为终止
                        n |= TERMINATION_BIT;
                    else if (nextUnarrived == 0)
                        // 下一个phase的parties为0,更新未到达的parties的值
                        n |= EMPTY;
                    else
                        // 更新下一个phase的未到达的parties的值
                        n |= nextUnarrived;
                    // phase值加1
                    n |= (long)((phase + 1) & MAX_PHASE) << PHASE_SHIFT;

                    // 更新state值
                    UNSAFE.compareAndSwapLong(this, stateOffset, s, n);

                    // 唤醒等待的线程
                    releaseWaiters(phase);
                }
                return phase;
            }
        }
    }

关于arrive还有一个方法:arriveAndAwaitAdvance。这个方法会等到下一个phase开始再返回,相等于doArrive方法添加了awaitAdvance方法的功能。基本逻辑和上面说的doArrive方法类似:

  public int arriveAndAwaitAdvance() {
        final Phaser root = this.root;
        for (;;) {
            long s = (root == this) ? state : reconcileState();
            int phase = (int)(s >>> PHASE_SHIFT);
            int counts = (int)s;
            int unarrived = (counts & UNARRIVED_MASK) - 1;
            if (phase < 0)
                return phase;
            else if (counts == EMPTY || unarrived < 0) {
                // 对于非子Phaser来说,因为可以等待下一个phase,
                // 所以不是非法arrive
                if (reconcileState() == s)
                    throw new IllegalStateException(badArrive(s));
            }
            else if (UNSAFE.compareAndSwapLong(this, stateOffset, s,
                                               s -= ONE_ARRIVAL)) {
                // 还有其他线程没有达到,就会等待直到下一个phase开始
                if (unarrived != 0)
                    return root.internalAwaitAdvance(phase, null);
                if (root != this)
                    return parent.arriveAndAwaitAdvance();
                long n = s & PARTIES_MASK;  // base of next state
                int nextUnarrived = (int)n >>> PARTIES_SHIFT;
                if (onAdvance(phase, nextUnarrived))
                    n |= TERMINATION_BIT;
                else if (nextUnarrived == 0)
                    n |= EMPTY;
                else
                    n |= nextUnarrived;
                int nextPhase = (phase + 1) & MAX_PHASE;
                n |= (long)nextPhase << PHASE_SHIFT;
                if (!UNSAFE.compareAndSwapLong(this, stateOffset, s, n))
                    return (int)(state >>> PHASE_SHIFT);
                releaseWaiters(phase);
                return nextPhase;
            }
        }
    }

所谓线程等待Phaser的当前phase结束并转到下一个phase的过程。Phaser提供了三个方法:

// 不可中断,没有超时的版本
public int awaitAdvance(int phase);

// 可以中断,没有超时的版本
public int awaitAdvanceInterruptibly(int phase);

// 可以中断,带有超时的版本
public int awaitAdvanceInterruptibly(int phase, long timeout, TimeUnit unit);

这三个版本的方法的实现大体类似,区别在于第二个版本多了中断异常,第三个版本多了中断异常和超时异常。

   public int awaitAdvance(int phase) {
        // 获取当前state
        final Phaser root = this.root;
        long s = (root == this) ? state : reconcileState();
        int p = (int)(s >>> PHASE_SHIFT);

        // 检查给定的phase是否和当前的phase一直
        if (phase < 0)
            return phase;
        if (p == phase)
            return root.internalAwaitAdvance(phase, null);
        return p;
    }

    // 多了一个对于中断的检查然后抛出中断异常
    public int awaitAdvanceInterruptibly(int phase)
        throws InterruptedException {
        final Phaser root = this.root;
        long s = (root == this) ? state : reconcileState();
        int p = (int)(s >>> PHASE_SHIFT);
        if (phase < 0)
            return phase;
        if (p == phase) {
            // 使用QNode实现中断和超时,这里不带超时
            QNode node = new QNode(this, phase, true, false, 0L);
            p = root.internalAwaitAdvance(phase, node);
            // 对于中断的情况,抛出中断异常
            if (node.wasInterrupted)
                throw new InterruptedException();
        }
        return p;
    }

    // 多了中断异常和超时异常
    public int awaitAdvanceInterruptibly(int phase,
                                         long timeout, TimeUnit unit)
        throws InterruptedException, TimeoutException {
        long nanos = unit.toNanos(timeout);
        final Phaser root = this.root;
        long s = (root == this) ? state : reconcileState();
        int p = (int)(s >>> PHASE_SHIFT);
        if (phase < 0)
            return phase;
        if (p == phase) {
            QNode node = new QNode(this, phase, true, true, nanos);
            p = root.internalAwaitAdvance(phase, node);
            // 中断异常
            if (node.wasInterrupted)
                throw new InterruptedException();
            // 没有进入下一个phase,抛出超时异常
            else if (p == phase)
                throw new TimeoutException();
        }
        return p;
    }

上述三个方法都是调用了internalAwaitAdvance方法来实现等待,因此来看internalAwaitAdvance方法:

  private int internalAwaitAdvance(int phase, QNode node) {
        // 释放上一个phase的资源
        releaseWaiters(phase-1);

        // node是否被加入到队列中
        boolean queued = false;

        // 记录前一个Unarrived,用来增加spin值
        int lastUnarrived = 0;
        int spins = SPINS_PER_ARRIVAL;
        long s;
        int p;

        // 循环操作直到phase值发生了变化
        while ((p = (int)((s = state) >>> PHASE_SHIFT)) == phase) {
            // 不可中断的模式,使用自旋等待
            if (node == null) {
                int unarrived = (int)s & UNARRIVED_MASK;
                if (unarrived != lastUnarrived &&
                    (lastUnarrived = unarrived) < NCPU)
                    spins += SPINS_PER_ARRIVAL;
                boolean interrupted = Thread.interrupted();
                // 发生了中断时,使用一个node来记录这个中断
                if (interrupted || --spins < 0) {
                    node = new QNode(this, phase, false, false, 0L);
                    node.wasInterrupted = interrupted;
                }
            }
            // 当前线程的node可以结束等待了,后面会分析isReleasible方法
            else if (node.isReleasable())
                break;
            // 把node加入到队列中
            else if (!queued) {
                // 根据phase值不同,使用不同的队列
                AtomicReference<QNode> head = (phase & 1) == 0 ? evenQ : oddQ;
                QNode q = node.next = head.get();
                // 检查队列的phase是否和要求的phase一致并且Phaser的phase没有发生变化
                // 符合这两个条件才把node添加到队列中去
                if ((q == null || q.phase == phase) &&
                    (int)(state >>> PHASE_SHIFT) == phase)
                    queued = head.compareAndSet(q, node);
            }
            // node加入队列后直接等待
            else {
                try {
            // 对于普通线程来说,这个方法作用就是循环直到isReleasable返回true
            // 或者block方法返回true
                    ForkJoinPool.managedBlock(node);
                } catch (InterruptedException ie) {
                    node.wasInterrupted = true;
                }
            }
        }

        // 对于进入队列的node,重置一些属性
        if (node != null) {
            // 释放thread,不要再使用unpark
            if (node.thread != null)
                node.thread = null;
            // 对于不可中断模式下发生的中断,清除中断状态
            if (node.wasInterrupted && !node.interruptible)
                Thread.currentThread().interrupt();
            // phase依旧没有变化表明同步过程被终止了
            if (p == phase && (p = (int)(state >>> PHASE_SHIFT)) == phase)
                return abortWait(phase);
        }

        // 通知所有的等待线程
        releaseWaiters(phase);
        return p;
    }

下面来看QNode,它实现了ManagedBlocker接口(见ForkJoinPool),ManagedBlocker包含两个方法:isReleasable和block。

isReleasable表示等待可以结束了,下面是QNode实现的isReleasable:

     public boolean isReleasable() {
            // 没了等待线程,通常会在外部使用"node.thread = null"来释放等待线程,这时可以结束等待
            if (thread == null)
                return true;
            // phase发生变化,可以结束等待
            if (phaser.getPhase() != phase) {
                thread = null;
                return true;
            }

            // 可中断的情况下发生线程中断,可以结束等待
            if (Thread.interrupted())
                wasInterrupted = true;
            if (wasInterrupted && interruptible) {
                thread = null;
                return true;
            }

            // 设置超时的情况下,发生超时,可以结束等待
            if (timed) {
                if (nanos > 0L) {
                    long now = System.nanoTime();
                    nanos -= now - lastTime;
                    lastTime = now;
                }
                if (nanos <= 0L) {
                    thread = null;
                    return true;
                }
            }
            return false;
        }

最后来看QNode实现的block方法,核心思想是用LockSupport来实现线程等待:

       public boolean block() {
            if (isReleasable())
                return true;
            // 没有设置超时的情况
            else if (!timed)
                LockSupport.park(this);
            // 设置超时的情况
            else if (nanos > 0)
                LockSupport.parkNanos(this, nanos);
            return isReleasable();
        }

最后来看releaseWaiters方法,看看怎么释放node队列:

    private void releaseWaiters(int phase) {
        QNode q;
        Thread t;
        AtomicReference<QNode> head = (phase & 1) == 0 ? evenQ : oddQ;

        // 如果phase已经发生了变化,才能释放
        while ((q = head.get()) != null &&
               q.phase != (int)(root.state >>> PHASE_SHIFT)) {
            // 释放节点并转到下一个节点
            if (head.compareAndSet(q, q.next) &&
                (t = q.thread) != null) {
                // 释放线程
                q.thread = null;
                // 通知线程结束等待
                LockSupport.unpark(t);
            }
        }
    }