【Java8源码分析】线程-ThreadLocal的全面剖析

时间:2022-10-31 17:34:29

一、背景

ThreadLocal类顾名思义就是,申明为ThreadLocal的变量,对于不同线程来说都是独立的。

下面是一个例子:

public class Test {

public static void main(String[] args) {
ThreadLocalTest threadLocalTest = new ThreadLocalTest();

for(int i = 0; i < 3; i++) {
TaskTest taskTest = new TaskTest(threadLocalTest);
Thread t = new Thread(taskTest);
t.start();
try {
Thread.sleep(200);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}

static class TaskTest implements Runnable{
ThreadLocalTest test;

public TaskTest(ThreadLocalTest test) {
this.test = test;
}

@Override
public void run() {
int tmp = test.localCnt.get();
test.localCnt.set(tmp + 1);
test.shareCnt += 1;
System.out.println(Thread.currentThread().getName());
System.out.println("LocalCnt:" + test.localCnt.get());
System.out.println("SharedCnt:" + test.shareCnt);
}
}

static class ThreadLocalTest
{
ThreadLocal<Integer> localCnt = new ThreadLocal<Integer>() {
public Integer initialValue() {
return 0;
}
};

int shareCnt = 0;

public ThreadLocalTest() {

}
}
}

输出结果:

Thread-0
LocalCnt:1
SharedCnt:1
Thread-1
LocalCnt:1
SharedCnt:2
Thread-2
LocalCnt:1
SharedCnt:3

基本原理:ThreadLocal会为每一个线程提供一个独立的变量副本,从而隔离了多个线程对数据的访问冲突。因为每一个线程都拥有自己的变量副本,从而也就没有必要对该变量进行同步了。ThreadLocal提供了线程安全的共享对象,在编写多线程代码时,可以把不安全的变量封装进ThreadLocal。

二、存储结构

在ThreadLocal类中定义了一个重要静态内部类,ThreadLocalMap,用来存储每个线程的局部变量,代码如下

    static class ThreadLocalMap {

// Entry继承自WeakReference类,是存储线程私有变量的数据结构
// ThreadLocal实例作为引用,意味着如果ThreadLocal实例为null
// 就可以从table中删除对应的Entry。
static class Entry extends WeakReference<ThreadLocal<?>> {

Object value;

// 把ThreadLocal与value封装成Entry
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

// 数组初始大小为16
private static final int INITIAL_CAPACITY = 16;

// 存储数组
private Entry[] table;

private int size = 0;
private int threshold; // 默认为0
private void setThreshold(int len) {
threshold = len * 2 / 3;
}

private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}

// 构造函数
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}

private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];

for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}

private void set(ThreadLocal<?> key, Object value) {

Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();

if (k == key) {
e.value = value;
return;
}

if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}

tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
}

三、主要方法

public class ThreadLocal<T> {

// 跟hash值相关的部分
private final int threadLocalHashCode = nextHashCode();
private static AtomicInteger nextHashCode =
new AtomicInteger();
private static final int HASH_INCREMENT = 0x61c88647;
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}

// 此方法在每个线程中最多执行一次,如果第一次执行get(),会调用此方法
// 如果在第一次执行get()之前已经调用过set(),则此方法永远不执行
// 可以看到默认返回null值,为了避免不必要错误,最好重写此方法
protected T initialValue() {
return null;
}

// 构造函数
public ThreadLocal() {
}

// 获取线程所属的值
public T get() {

// 获取当前线程
Thread t = Thread.currentThread();

// 每个线程有维护一个ThreadLocalMap变量,调用getMap获取
ThreadLocalMap map = getMap(t);

// 如果map不为空
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);

// 如果map中已经有该ThreadLocal的值,返回
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}

// 在没有map或map中没有添加该ThreadLocal时调用初始化
return setInitialValue();
}

// 初始化
private T setInitialValue() {
// 调用initialValue获取默认值
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
// 如果Thread中并没有map,则新建一个
// 这里注意,是每个Thread维护一个ThreadLocalMap
createMap(t, value);
return value;
}

// 赋值
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
// 这里同样有可能调用创建ThreadLocalMap
createMap(t, value);
}

public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}

// 返回Thread中维护的TreadLocalMap
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

// 如果Thread中并没有map,则新建一个
// 这里注意,是每个Thread维护一个ThreadLocalMap
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
}

四、总结

ThreadLocal类最重要的一个概念是,其原理是通过一个ThreadLocal的静态内部类ThreadLocalMap实现,但是实际中,ThreadLocal不保存ThreadLocalMap,而是有每个Thread内部维护ThreadLocal.ThreadLocalMap threadLocals一份数据结构。

这里画张图更容易理解,假如我们有如下的代码

class ThreadLocalDemo
{
ThreadLocal<Integer> localA = new ThreadLocal<Integer>();
ThreadLocal<Integer> localB = new ThreadLocal<Integer>();
}

在多线程环境下,数据结构应该是如下图所示


【Java8源码分析】线程-ThreadLocal的全面剖析