多线程等待所有子线程执行完使用总结(2)——CountDownLatch使用和源码初步分析

时间:2023-02-16 14:56:09

问题背景

我们在日常开发和学习过程中,经常会使用到多线程的场景,其中我们经常会碰到,我们代码需要等待某个或者多个线程执行完再开始执行,上一篇文章中(参考 https://blog.51cto.com/baorant24/6059489 ),我们介绍了object的wait()和notify(),以及线程的join()方法来实现,本文将介绍一种新的方案,CountDownLatch类的使用。

问题分析

大家在日常开发和学习过程中,或多或少都使用过CountDownLatch,知道CountDownLatch的一般用法,CountDownLatch类是用在同步,允许一个或多个线程去等待直到另外的线程完成了一组操作。它通过count进行初始化,await方法会阻塞直到当前的count为0,之后所有的线程将被释放。了解CountDownLatch的一般作用后,我们先来一起看看CountDownLatch的源码相关分析。 重点源码梳理如下:

public class CountDownLatch {
    // 内部类Sync,继承自AbstractQueuedSynchronizer,并且重写了tryAcquireShared()和tryReleaseShared方法。
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c - 1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

    // 构造方法传入count计数值,然后构造内部的sync对象
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
    
    
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
    ...

    public void countDown() {
        sync.releaseShared(1);
    }
    
    // 获取计数值
    public long getCount() {
        return sync.getCount();
    }
    ...
}

核心代码分析: (1)await()方法分析 java.util.concurrent.CountDownLatch#await()


    public void await() throws InterruptedException {
       sync.acquireSharedInterruptibly(1);
   }

java.util.concurrent.locks.AbstractQueuedSynchronizer#acquireSharedInterruptibly


    public final void acquireSharedInterruptibly(int arg)
           throws InterruptedException {
       if (Thread.interrupted())
           throw new InterruptedException();
       // 前面提到CountDownLatch类中Sync内部类对tryAcquireShared进行了重写
       if (tryAcquireShared(arg) < 0)
           doAcquireSharedInterruptibly(arg);
   }

java.util.concurrent.CountDownLatch.Sync#tryAcquireShared

              
        // 判断state值是否为0,0的话返回1,不然返回-1。
        protected int tryAcquireShared(int acquires) {
           return (getState() == 0) ? 1 : -1;
       }

所有上面 java.util.concurrent.locks.AbstractQueuedSynchronizer#acquireSharedInterruptibly方法会走doAcquireSharedInterruptibly(arg),继续分析该方法。 java.util.concurrent.locks.AbstractQueuedSynchronizer#doAcquireSharedInterruptibly


    private void doAcquireSharedInterruptibly(int arg)
       throws InterruptedException {
       final Node node = addWaiter(Node.SHARED);
       try {
           for (;;) {
               final Node p = node.predecessor();
               if (p == head) {
                   int r = tryAcquireShared(arg);
                   // 根据前面的分析知道,这块r在state不等于0时,返回-1,所有state未减少到0时,这块回事for循环自旋一直等待。
                   if (r >= 0) {
                       setHeadAndPropagate(node, r);
                       p.next = null; // help GC
                       return;
                   }
               }
               if (shouldParkAfterFailedAcquire(p, node) &&
                   parkAndCheckInterrupt())
                   throw new InterruptedException();
           }
       } catch (Throwable t) {
           cancelAcquire(node);
           throw t;
       }
   }

(2)countDown()方法分析 java.util.concurrent.CountDownLatch#countDown


    public void countDown() {
       sync.releaseShared(1);
   }

java.util.concurrent.locks.AbstractQueuedSynchronizer#releaseShared


   public final boolean releaseShared(int arg) {
       // 同样,前面提到,tryReleaseShared方法在CountDownLatch类中Sync内部类中进行了重写
       if (tryReleaseShared(arg)) {
           doReleaseShared();
           return true;
       }
       return false;
   }

java.util.concurrent.CountDownLatch.Sync#tryReleaseShared

        protected boolean tryReleaseShared(int releases) {
           // Decrement count; signal when transition to zero
           for (;;) {
               int c = getState();
               if (c == 0)
                   return false;
               int nextc = c - 1;
               // 通过compareAndSet保证state-1的同步正确性。
               if (compareAndSetState(c, nextc))
   	    // 当改到0后return TRUE,不然 return FALSE。
                   return nextc == 0;
           }
       }

问题解决

大概梳理了CountDownLatch的源码后,我们实践一个demo具体再看下,代码如下:

import android.os.Bundle
import android.util.Log
import androidx.appcompat.app.AppCompatActivity
import java.util.concurrent.CountDownLatch

class TestCountDownLatchActivity : AppCompatActivity() {
    private var mCountDownLatch: CountDownLatch? = null

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_test_count_down_latch)

        mCountDownLatch = CountDownLatch(3);

        Thread {
            try {
                Log.d("TAG", "Thread Main A start:" + mCountDownLatch?.count);
                mCountDownLatch?.await();
                Log.d("TAG", "Thread Main A end:" + mCountDownLatch?.count);
            } catch (e: InterruptedException) {
                e.printStackTrace();
            }
        }.start()

        Thread {
            try {
                Log.d("TAG", "Thread Main B start:" + mCountDownLatch?.count);
                mCountDownLatch?.await();
                Log.d("TAG", "Thread Main B end:" + mCountDownLatch?.count);
            } catch (e: InterruptedException) {
                e.printStackTrace();
            }
        }.start()

        Thread {
            Log.d("TAG", "Thread A start:" + mCountDownLatch?.count);
            mCountDownLatch?.countDown();
            Log.d("TAG", "Thread A end:" + mCountDownLatch?.count);
        }.start()

        Thread {
            Log.d("TAG", "Thread B start:" + mCountDownLatch?.count);
            mCountDownLatch?.countDown();
            Log.d("TAG", "Thread B end:" + mCountDownLatch?.count);
        }.start()

        Thread {
            Log.d("TAG", "Thread C start:" + mCountDownLatch?.count);
            mCountDownLatch?.countDown();
            Log.d("TAG", "Thread C end:" + mCountDownLatch?.count);
        }.start()
    }
}

运行结果如下: 多线程等待所有子线程执行完使用总结(2)——CountDownLatch使用和源码初步分析 结果分析: 线程Main A和线程Main B在执行过程中调用mCountDownLatch?.await()方法,会等待mCountDownLatch的计数值减少到0后再继续执行。其余三个线程Thread A、Thread B、 Thread C执行过程将CountDownLatch分别减1,所以这三个线程都执行完后,线程Main A和线程Main B才会继续执行。

问题总结

我们在日常开发和学习过程中,经常会使用到多线程的场景,其中经常会碰到,我们代码需要等待某个或者多个线程执行完再开始执行,上一篇文章中,我们介绍了object的wait()和notify(),以及线程的join()方法来实现,本文介绍了一种新的方案,CountDownLatch类的使用和源码的初步分析,有兴趣的同学可以进一步深入研究。