CountDownLatch源码分析

基于jdk1.8源码进行分析的。

CountDownLatch是JDK 5+里面闭锁的一个实现,相当于在开放之前,所有的线程都会被阻塞,当开放之后,所有线程都进行执行,但是开放之后就不能在关闭。它有一个正数计数器,countDown方法对计数器做减操作,await方法等待计数器达到0,所有await的线程都会阻塞直到计数器为0或者等待线程中断或者超时。下面我们看一下DEMO演示。

package demo;

import java.util.concurrent.CountDownLatch;

/**
 * CountDownLatch测试例程
 * @ClassName:   CountDownLatchDemo  
 * @Description: TODO
 * @author       BurgessLee
 * @date         2019年4月26日  
 *
 */
public class CountDownLatchDemo {
	
	public static void main(String[] args) {
		CountDownLatch countDownLatch = new CountDownLatch(1);
		new Thread(new Runnable() {
			
			@Override
			public void run() {
				System.out.println(Thread.currentThread().getName() +"-start....count=" +countDownLatch.getCount());
				try {
					countDownLatch.await();
					System.out.println(Thread.currentThread().getName()+"-end....count=" + countDownLatch.getCount());
				} catch (InterruptedException e) {
					e.printStackTrace();
				}
			}
		},"Thread0").start();
		new Thread(new Runnable() {
					
			@Override
			public void run() {
				System.out.println(Thread.currentThread().getName() +"-start....count=" +countDownLatch.getCount());
				try {
					countDownLatch.await();
					System.out.println(Thread.currentThread().getName()+"-end....count=" + countDownLatch.getCount());
				} catch (InterruptedException e) {
					e.printStackTrace();
				}
			}
		},"Thread1").start();
		new Thread(new Runnable() {
			
			@Override
			public void run() {
				System.out.println(Thread.currentThread().getName() +"-start....count=" +countDownLatch.getCount());
				try {
					countDownLatch.await();
					System.out.println(Thread.currentThread().getName()+"-end....count=" + countDownLatch.getCount());
				} catch (InterruptedException e) {
					e.printStackTrace();
				}
			}
		},"Thread2").start();
		new Thread(new Runnable() {
			
			@Override
			public void run() {
				System.out.println(Thread.currentThread().getName() +"-start....count=" +countDownLatch.getCount());
				try {
					countDownLatch.await();
					System.out.println(Thread.currentThread().getName()+"-end....count=" + countDownLatch.getCount());
				} catch (InterruptedException e) {
					e.printStackTrace();
				}
			}
		},"Thread3").start();
		
		try {
			Thread.sleep(2000);
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
		
		System.out.println(Thread.currentThread().getName() +"-start....count=" +countDownLatch.getCount());
		countDownLatch.countDown();
		
		try {
			Thread.sleep(2000);
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
		
		System.out.println(Thread.currentThread().getName() +"-start....count=" +countDownLatch.getCount());
		try {
			Thread.sleep(2000);
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
		System.out.println(Thread.currentThread().getName() +" ending...............");
	}
	
}

以上是整个测试例程,输出结果如下:

Thread0-start....count=1
Thread1-start....count=1
Thread2-start....count=1
Thread3-start....count=1
main-start....count=1
Thread0-end....count=0
Thread1-end....count=0
Thread2-end....count=0
Thread3-end....count=0
main-start....count=0

根据例程我们大致应该明白了大致用法。下面我们看一下该类的源码。

类声明

public class CountDownLatch

构造函数

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

做了一个数据校验的操作,然后同时初始化的是Sync的实例。下面我们看一下Sync。

内部类

Sync

 同样看到是AQS的实现类,经过之前很多代码的分析估计已经很熟悉不过了。

    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

成员属性

private final Sync sync;

成员方法

 toString()

    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }

第一个方法就是toString方法,没想到吧。23333。我们都知道这里是通过Object来的,进行了重写,不再过多言语介绍了。

await()

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

可以看见是通过调用sync的acquireSharedInterruptibly()方法实现的,并且args为1。在AQS中我们看到具体代码实现。

    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        //线程中断,直接抛出异常
        if (Thread.interrupted())
            throw new InterruptedException();
        //小于0,说明失败,然后执行相应动作
        //如果大于0,那么此时状态已经为0,此时所有线程放行
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }

里面涉及到了tryAcquireShared和doAcquireSharedInterruptibly两个方法。我们看一下。

tryAcquireShared(int acquires)

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

当然实现在Sync中重写改方法。可以看到简单粗暴,直接拿state状态值和0作比较,返回值1和-1。

    //在AQS同步队列进行等待,并不断的自旋检测是否需要唤醒
    private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

该方法整个实现过程详见AQS,分析到现在此时我做了一个文件对比,不知道你是否也有同样的疑问。

CountDownLatch源码分析

在AQS中存在一下几个方法,如下图所示:

CountDownLatch源码分析

三者的区别,第一个和第三个很容易区分,有时长的一个限制,那么第一个和第二个区别在上上图文件对比已经看到足够清除了,就是关于线程中断的处理是不一样的。一个有关于线程标志位的相应处理,一个只是抛出了对应的线程中断异常。上面就当做一个插曲吧,下面我们继续看源码。

countDown()

    public void countDown() {
        sync.releaseShared(1);
    }

涉及到调用的方法releaseShared,我们继续往下看。

    //尝试释放共享锁
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }
    protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    private void doReleaseShared() {

        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

以上方法在前面已经分析过了,不在一一介绍。这里如果想了解,可以去看Semaphore源码分析

处于好奇,将上面的DEMO改造一下,如下:

package demo;

import java.util.concurrent.CountDownLatch;

/**
 * CountDownLatch测试例程
 * @ClassName:   CountDownLatchDemo  
 * @Description: TODO
 * @author       BurgessLee
 * @date         2019年4月26日  
 *
 */
public class CountDownLatchDemo {
	
	public static void main(String[] args) {
		CountDownLatch countDownLatch = new CountDownLatch(1);
		Thread t0 = new Thread(new Runnable() {
			
			@Override
			public void run() {
				System.out.println(Thread.currentThread().getName() +"-start....count=" +countDownLatch.getCount());
				try {
					countDownLatch.await();
					System.out.println(Thread.currentThread().getName()+"-end....count=" + countDownLatch.getCount());
				} catch (InterruptedException e) {
					e.printStackTrace();
				}
			}
		},"Thread0");
		Thread t1 = new Thread(new Runnable() {
					
			@Override
			public void run() {
				System.out.println(Thread.currentThread().getName() +"-start....count=" +countDownLatch.getCount());
				try {
					countDownLatch.await();
					System.out.println(Thread.currentThread().getName()+"-end....count=" + countDownLatch.getCount());
				} catch (InterruptedException e) {
					e.printStackTrace();
				}
			}
		},"Thread1");
		Thread t2 = new Thread(new Runnable() {
			
			@Override
			public void run() {
				System.out.println(Thread.currentThread().getName() +"-start....count=" +countDownLatch.getCount());
				try {
					countDownLatch.await();
					System.out.println(Thread.currentThread().getName()+"-end....count=" + countDownLatch.getCount());
				} catch (InterruptedException e) {
					e.printStackTrace();
				}
			}
		},"Thread2");
		Thread t3 = new Thread(new Runnable() {
			
			@Override
			public void run() {
				System.out.println(Thread.currentThread().getName() +"-start....count=" +countDownLatch.getCount());
				try {
					countDownLatch.await();
					System.out.println(Thread.currentThread().getName()+"-end....count=" + countDownLatch.getCount());
				} catch (InterruptedException e) {
					e.printStackTrace();
				}
			}
		},"Thread3");
		t0.start();
		t1.start();
		t2.start();
		t3.start();
		try {
			Thread.sleep(2000);
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
		System.out.println("t0中断");
		t0.interrupt();
		System.out.println("t0中断成功");
		try {
			Thread.sleep(1000);
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
		System.out.println(Thread.currentThread().getName() +"-count=" +countDownLatch.getCount());
		countDownLatch.countDown();
		
		try {
			Thread.sleep(2000);
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
		System.out.println(Thread.currentThread().getName() +" ending...............");
	}
	
}

可以看到区别就是加上了线程中断,输出结果如下:

Thread0-start....count=1
Thread3-start....count=1
Thread2-start....count=1
Thread1-start....count=1
t0中断
t0中断成功
java.lang.InterruptedException
	at java.util.concurrent.locks.AbstractQueuedSynchronizer.doAcquireSharedInterruptibly(AbstractQueuedSynchronizer.java:998)
	at java.util.concurrent.locks.AbstractQueuedSynchronizer.acquireSharedInterruptibly(AbstractQueuedSynchronizer.java:1304)
	at java.util.concurrent.CountDownLatch.await(CountDownLatch.java:231)
	at demo.CountDownLatchDemo$1.run(CountDownLatchDemo.java:23)
	at java.lang.Thread.run(Thread.java:745)
main-count=1
Thread3-end....count=0
Thread2-end....count=0
Thread1-end....count=0
main ending...............

从结果来看,相对比少了

Thread0-end....count=0

这里想说明什么呢?因为整体做了try-catch处理,所以程序正常执行了,而且await方法当检测线程出现中断的时候,那么此时抛出了异常。

以上就是此次CountDownLatch类的源码分析过程。如果有不对的地方还请指正。