一篇带给你CountDownLatch实现原理

时间:2021-07-29 07:02:53

一篇带给你CountDownLatch实现原理

 前言

 

CountDownLatch是多线程中一个比较重要的概念,它可以使得一个或多个线程等待其他线程执行完毕之后再执行。它内部有一个计数器和一个阻塞队列,每当一个线程调用countDown()方法后,计数器的值减少1。当计数器的值不为0时,调用await()方法的线程将会被加入到阻塞队列,一直阻塞到计数器的值为0。

常用方法

 

  1. public class CountDownLatch { 
  2.  
  3.     //构造一个值为count的计数器 
  4.     public CountDownLatch(int count); 
  5.  
  6.     //阻塞当前线程直到计数器为0 
  7.     public void await() throws InterruptedException; 
  8.  
  9.     //在单位为unit的timeout时间之内阻塞当前线程 
  10.     public boolean await(long timeout, TimeUnit unit); 
  11.  
  12.     //将计数器的值减1,当计数器的值为0时,阻塞队列内的线程才可以运行 
  13.     public void countDown();       
  14.  

下面给一个简单的示例:

  1. package com.yang.testCountDownLatch; 
  2.  
  3. import java.util.concurrent.CountDownLatch; 
  4.  
  5. public class Main { 
  6.     private static final int NUM = 3; 
  7.  
  8.     public static void main(String[] args) throws InterruptedException { 
  9.         CountDownLatch latch = new CountDownLatch(NUM); 
  10.         for (int i = 0; i < NUM; i++) { 
  11.             new Thread(() -> { 
  12.                 try { 
  13.                     Thread.sleep(2000); 
  14.                     System.out.println(Thread.currentThread().getName() + "运行完毕"); 
  15.                 } catch (InterruptedException e) { 
  16.                     e.printStackTrace(); 
  17.                 } finally { 
  18.                     latch.countDown(); 
  19.                 } 
  20.             }).start(); 
  21.         } 
  22.         latch.await(); 
  23.         System.out.println("主线程运行完毕"); 
  24.     } 

输出如下:

一篇带给你CountDownLatch实现原理

看得出来,主线程会等到3个子线程执行完毕才会执行。

原理解析

 

类图

一篇带给你CountDownLatch实现原理

可以看得出来,CountDownLatch里面有一个继承AQS的内部类Sync,其实是AQS来支持CountDownLatch的各项操作的。

CountDownLatch(int count)

 

new CountDownLatch(int count)用来创建一个AQS同步队列,并将计数器的值赋给了AQS的state。

  1. public CountDownLatch(int count) { 
  2.     if (count < 0) throw new IllegalArgumentException("count < 0"); 
  3.     this.sync = new Sync(count); 
  4.  
  5. private static final class Sync extends AbstractQueuedSynchronizer {      
  6.     Sync(int count) { 
  7.         setState(count); 
  8.     } 
  9.  

countDown()

 

countDown()方法会对计数器进行减1的操作,当计数器值为0时,将会唤醒在阻塞队列中等待的所有线程。其内部调用了Sync的releaseShared(1)方法

  1. public void countDown() { 
  2.      sync.releaseShared(1); 
  3.  } 
  4.  
  5.  public final boolean releaseShared(int arg) { 
  6.      if (tryReleaseShared(arg)) { 
  7.          //此时计数器的值为0,唤醒所有被阻塞的线程 
  8.          doReleaseShared(); 
  9.          return true
  10.      } 
  11.      return false
  12.  } 

tryReleaseShared(arg)内部使用了自旋+CAS操将计数器的值减1,当减为0时,方法返回true,将会调用doReleaseShared()方法。对CAS机制不了解的同学,可以先参考我的另外一篇文章浅探CAS实现原理

  1. protected boolean tryReleaseShared(int releases) { 
  2.       //自旋 
  3.       for (;;) { 
  4.           int c = getState(); 
  5.           if (c == 0) 
  6.               //此时计数器的值已经为0了,其他线程早就执行完毕了,当前线程也已经再执行了,不需要再次唤醒了 
  7.               return false
  8.           int nextc = c-1; 
  9.           //使用CAS机制,将state的值变为state-1 
  10.           if (compareAndSetState(c, nextc)) 
  11.               return nextc == 0; 
  12.       } 
  13.   } 

doReleaseShared()是AQS中的方法,该方法会唤醒队列中所有被阻塞的线程。

  1. private void doReleaseShared() { 
  2.      for (;;) { 
  3.          Node h = head; 
  4.          if (h != null && h != tail) { 
  5.              int ws = h.waitStatus; 
  6.              if (ws == Node.SIGNAL) { 
  7.                  if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) 
  8.                      continue;            // loop to recheck cases 
  9.                  unparkSuccessor(h); 
  10.              } 
  11.              else if (ws == 0 && 
  12.                       !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) 
  13.                  continue;                // loop on failed CAS 
  14.          } 
  15.          if (h == head)                   // loop if head changed 
  16.              break; 
  17.      } 
  18.  } 

这段方法比较难理解,会另外篇幅介绍。这里只要认为该段方法会唤醒所有因调用await()方法而阻塞的线程。

await()

 

当计数器的值不为0时,该方法会将当前线程加入到阻塞队列中,并把当前线程挂起。

  1. public void await() throws InterruptedException { 
  2.     sync.acquireSharedInterruptibly(1); 

同样是委托内部类Sync,调用其

acquireSharedInterruptibly()方法

  1. public final void acquireSharedInterruptibly(int arg) 
  2.           throws InterruptedException { 
  3.       if (Thread.interrupted()) 
  4.           throw new InterruptedException(); 
  5.       if (tryAcquireShared(arg) < 0) 
  6.           doAcquireSharedInterruptibly(arg); 
  7.   } 

接着看Sync内的tryAcquireShared()方法,如果当前计数器的值为0,则返回1,最终将导致await()不会将线程阻塞。如果当前计数器的值不为0,则返回-1。

  1. protected int tryAcquireShared(int acquires) { 
  2.         return (getState() == 0) ? 1 : -1; 
  3.     } 

tryAcquireShared方法返回一个负值时,将会调用AQS中的

doAcquireSharedInterruptibly()方法,将调用await()方法的线程加入到阻塞队列中,并将此线程挂起。

  1. private void doAcquireSharedInterruptibly(int arg) 
  2.       throws InterruptedException { 
  3.       //将当前线程构造成一个共享模式的节点,并加入到阻塞队列中 
  4.       final Node node = addWaiter(Node.SHARED); 
  5.       boolean failed = true
  6.       try { 
  7.           for (;;) { 
  8.               final Node p = node.predecessor(); 
  9.               if (p == head) {         
  10.                   int r = tryAcquireShared(arg); 
  11.                   if (r >= 0) { 
  12.                       setHeadAndPropagate(node, r); 
  13.                       p.next = null; // help GC 
  14.                       failed = false
  15.                       return
  16.                   } 
  17.               } 
  18.               if (shouldParkAfterFailedAcquire(p, node) && 
  19.                   parkAndCheckInterrupt()) 
  20.                   throw new InterruptedException(); 
  21.           } 
  22.       } finally { 
  23.           if (failed) 
  24.               cancelAcquire(node); 
  25.       } 
  26.   } 

同样,以上的代码位于AQS中,在没有了解AQS结构的情况下去理解上述代码,有些困难,关于AQS源码,会另开篇幅介绍。

使用场景

 

CountDownLatch的使用场景很广泛,一般用于分头做某些事,再汇总的情景。例如:

数据报表:当前的微服务架构十分流行,大多数项目都会被拆成若干的子服务,那么报表服务在进行统计时,需要向各个服务抽取数据。此时可以创建与服务数相同的线程数,交由线程池处理,每个线程去对应服务中抽取数据,注意需要在finally语句块中进行countDown()操作。主线程调用await()阻塞,直到所有数据抽取成功,最后主线程再进行对数据的过滤组装等,形成直观的报表。

风险评估:客户端的一个同步请求查询用户的风险等级,服务端收到请求后会请求多个子系统获取数据,然后使用风险评估规则模型进行风险评估。如果使用单线程去完成这些操作,这个同步请求超时的可能性会很大,因为服务端请求多个子系统是依次排队的,请求子系统获取数据的时间是线性累加的。此时可以使用CountDownLatch,让多个线程并发请求多个子系统,当获取到多个子系统数据之后,再进行风险评估,这样请求子系统获取数据的时间就等于最耗时的那个请求的时间,可以大大减少处理时间。

原文地址:https://www.toutiao.com/i6919384508672295436/?group_id=6919384508672295436