单元测试多线程解决之道

时间:2022-08-31 18:02:42

遇到问题

曾今在开发的过程遇到一个问题,当时有一个服务是群发邮件的,由于一次发送几十个上百个,所以就使用了多线程来操作。

在单元测试的时候,我调了这个方法测试下邮件发送,结果总是出现莫名其妙的问题,每次都没有全部发送成功。

后来我感觉到启动的子线程都被杀掉了,好像测试方法一走完就over了,试着在测试方法末尾让线程睡眠个几秒,结果就能正常发送邮件。

分析解决

感觉这个Junit有点猫腻,就上网查了一下,再跟踪下源码,果然发现了问题所在。

TestRunner的main方法:

public static void main(String[] args) {
    TestRunner aTestRunner = new TestRunner();

    try {
        TestResult r = aTestRunner.start(args);
        if (!r.wasSuccessful()) {
            System.exit(1);
        }

        System.exit(0);
    } catch (Exception var3) {
        System.err.println(var3.getMessage());
        System.exit(2);
    }

}

上面显示了,不管成功与否,都会调用 System.exit() 方法关闭程序,这个方法是用来结束当前正在运行中的java虚拟机。

System.exit(0) 是正常退出程序,而 System.exit(1) 或者说非0表示非正常退出程序。

由此可见,junit 并不适合用来测试多线程程序呢,但是也不是没有方法,根据其原理可以尝试让主线程阻塞一下,等待其他子线程执行完毕再继续。

最简单的方法就是让主线程睡眠个几秒钟:

TimeUnit.SECONDS.sleep(5);

回顾复盘

除了让主线程睡眠以外,其实还有很多其他的工具可以帮我们解决这个问题。今天想起来了,就来试试吧。

来个数据库连接池相关的测试:

public class MultipleConnectionTest{

    private HikariDataSource ds;


    @Before
    public void setup() {
        HikariConfig config = new HikariConfig();
        config.setJdbcUrl("jdbc:mysql://127.0.0.1:3306/design");
        config.setDriverClassName("com.mysql.jdbc.Driver");
        config.setUsername("root");
        config.setPassword("fengcs");
        config.setMinimumIdle(1);
        config.setMaximumPoolSize(5);

        ds = new HikariDataSource(config);
    }

    @After
    public void teardown() {
        ds.close();
    }

    @Test
    public void testMulConnection() {

        ConnectionThread connectionThread = new ConnectionThread();
        Thread thread = null;
        for (int i = 0; i < 5; i++) {
            thread = new Thread(connectionThread, "thread-con-" + i);
            thread.start();
        }

        // TimeUnit.SECONDS.sleep(5);  (1)
    }

    private class ConnectionThread implements Runnable{

        @Override
        public void run() {
            Connection connection = null;
            try {
                connection = ds.getConnection();
                Statement statement =  connection.createStatement();
                ResultSet resultSet = statement.executeQuery("select id from tb_user");
                String firstValue;
                System.out.println("<=============");
                System.out.println("==============>"+Thread.currentThread().getName() + ":");
                while (resultSet.next()) {
                    firstValue = resultSet.getString(1);
                    System.out.print(firstValue);
                }
            } catch (SQLException e) {
                e.printStackTrace();
            } finally {
                try {
                    if (connection != null) {
                        connection.close();
                    }
                } catch (SQLException e) {
                    e.printStackTrace();
                }
            }
        }
    }

}

这个代码一跑起来就会报错:

java.sql.SQLException: HikariDataSource HikariDataSource (HikariPool-1) has been closed.

1、使用 join 方法

根据上面的代码,直接加个 join 试试:

@Test
public void testMulConnection() {

    ConnectionThread connectionThread = new ConnectionThread();
    Thread thread = null;
    for (int i = 0; i < 5; i++) {
        thread = new Thread(connectionThread, "thread-con-" + i);
        thread.start();
        thread.join();
    }

}

这样虽然可以成功执行,但仔细一看,和单个线程执行没有什么区别。对于主线程来说,start一个就join一个,开始阻塞等待子线程完成,然后循环开始第二个操作。

正确的操作应该类似这样:

Thread threadA = new Thread(connectionThread);
Thread threadB = new Thread(connectionThread);
threadA.start();
threadB.start();
threadA.join();
threadB.join();

这样多个线程可以一起执行。不过线程多了,这样写比较麻烦。

2、闭锁 - CountDownLatch

CountDownLatch 允许一个或多个线程等待其他线程完成操作。

CountDownLatch 的构造函数接收一个int类型的参数作为计数器,如果你想等待N个点完成,这里就传入N。

那么在这里,很明显主线程应该等待其他五个线程完成查询后再关闭。那么加上(1)和(2)处的代码,让主线程阻塞等待。

private static CountDownLatch latch = new CountDownLatch(5);  // (1)

@Test
public void testMulConnection() throws InterruptedException {

    ConnectionThread connectionThread = new ConnectionThread();
    Thread thread = null;
    for (int i = 0; i < 5; i++) {
        thread = new Thread(connectionThread, "thread-con-"+i);
        thread.start();
    }

    latch.await();   // (2)

}

当我们调用CountDownLatch的countDown方法时,N就会减1,CountDownLatch的await方法
会阻塞当前线程,直到N变成零。增加(3)处代码,每个线程完成查询后就将计数器减一。

private class ConnectionThread implements Runnable{

    @Override
    public void run() {
        Connection connection = null;
        try {
            connection = ds.getConnection();
            Statement statement =  connection.createStatement();
            ResultSet resultSet = statement.executeQuery("select id from tb_user");
            String firstValue;
            System.out.println("<=============");
            System.out.println("==============>"+Thread.currentThread().getName() + ":");
            while (resultSet.next()) {
                firstValue = resultSet.getString(1);
                System.out.print(firstValue);
            }
            
            latch.countDown(); // (3)
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            try {
                if (connection != null) {
                    connection.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }
}

测试一下,完全满足要求。

3、栅栏- CyclicBarrier

CyclicBarrier 的字面意思是可循环使用(Cyclic)的屏障(Barrier)。它要做的事情是,让一
组线程到达一个屏障(也可以叫同步点)时被阻塞,直到最后一个线程到达屏障时,屏障才会
开门,所有被屏障拦截的线程才会继续运行。

这里和 CountDownLatch 有所不同,但是主线程需要阻塞,依然在main方法末尾处加上一个同步点:

private static CyclicBarrier cyclicBarrier = new CyclicBarrier(6);  // (1)

@Test
public void testMulConnection() throws BrokenBarrierException, InterruptedException {

    ConnectionThread connectionThread = new ConnectionThread();
    Thread thread = null;
    for (int i = 0; i < 5; i++) {
        thread = new Thread(connectionThread, "thread-con-"+i);
        thread.start();
    }

    cyclicBarrier.await();   // (2)

}

CyclicBarrier默认的构造方法是 CyclicBarrier(int parties),其参数表示屏障拦截的线程数量,每个线程调用await方法告诉CyclicBarrier我已经到达了屏障,然后当前线程被阻塞。

这个时候没有类似闭锁的 countDown 方法来计数,只能靠线程到达同步点来确认是否都到达,而其他线程不会走main方法的同步点,所以还需要一个其他五个线程汇合的同步点。那么可以在每个线程 run 方法末尾 await 一下:

private class ConnectionThread implements Runnable{

    @Override
    public void run() {
        Connection connection = null;
        try {
            connection = ds.getConnection();
            Statement statement =  connection.createStatement();
            ResultSet resultSet = statement.executeQuery("select id from tb_user");
            String firstValue;
            System.out.println("<=============");
            System.out.println("==============>"+Thread.currentThread().getName() + ":");
            while (resultSet.next()) {
                firstValue = resultSet.getString(1);
                System.out.print(firstValue);
            }
            
            cyclicBarrier.await();  // (3)
        } catch (SQLException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            e.printStackTrace();
        } catch (BrokenBarrierException e) {
            e.printStackTrace();
        } finally {
            try {
                if (connection != null) {
                    connection.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }
}

这样就感觉两者有一个潜在的通信机制,都到了就一起放开。只不过现在是六个线程参与计数了,CyclicBarrier 构造器传参应该是6(小于6也可能成功,大于6一定会一直阻塞)。

综合看了一下,我觉得最合适的还是 CountDownLatch。

这里主要是借单元测试多线程来加深下对并发相关知识点的理解,将其用于实践,来解决一些问题。关于这个单元测试多线程的问题很多人应该都知道,当初离职前面试过几个人,也问了这个问题,有几个说遇到过,我问为什么存在这个问题,你又是怎么解决的?结果没一个答得上来。

其实遇到问题是好事,都是成长的机会,每一个问题后面都隐藏着很多盲点,深挖下去一定收获颇多。