ThreadLocal源码阅读
一、ThreadLocal概述
在Java开发中,多线程是一个永远绕不开的话题。Java服务器中通常使用一组线程处理一个会话或一个连接,这一组线程一般具有父子关系,它们往往需要客户端传递的用户信息来完成业务逻辑,这些信息仅针对本次连接或会话有效。当然我们可以以参数的形式传递这些数据,但当参数异常复杂的时候,层间传递就变得很枯燥,有没有一种方法可以简化这种场景?当然有,他就是ThreadLocal技术,ThreadLocal是一种线程局部变量共享机制,它独立于语言,但本次仅讨论Java语言范畴。
使用ThreadLocal维护的变量在每个线程内部维护一个副本,对该变量的修改仅对该线程可见(指ThreadLocal.set(obj),对象成员修改除外),同样可以在线程内部的任意位置使用它。这样,既满足了线程之间的隔离要求,又减少了参数的传递工作,何乐而不为之!另外,我们上面提到,如果是一组线程(这一组线程具有父子关系)需要变量共享呢?没问题,Java给我们提供了InheritableThreadLocal,只要我们在父线程中设置了相关变量,子线程会自动继承这些变量的值,但本质上是子线程初始化自己的副本时使用了父线程的值,此后对各自副本的修改(指ThreadLocal.set(obj),对象成员修改除外)也仅在当前线程生效。
注意,网上很多博客写着ThreadLocal是用来解决多线程并发问题的,这种理解在我看来是错误的。多线程并发指的是多个线程对同一个临界区操作的互斥问题,或多个线程之间的同步问题,而ThreadLocal在每个线程都有一个副本,不存在互斥,也不存在同步,因此跟多线程并发问题无关。
二、ThreadLocal运用
上面简单介绍了ThreadLocal的性质,下面来看一下具体的运用场景:
-
Session管理:正如我们开篇提到的用户信息传递问题,本质上就是一个Session管理,它在一次会话开始的时候创建ThreadLocal变量保存所有全局信息,会话结束的时候释放ThreadLocal。如果线程的生命周期与会话的生命周期一致,则可以不用手动释放ThreadLocal变量,如果使用了线程池就必须在提交任务时手动初始化ThreadLocal,结束任务时手动清理ThreadLocal保存的数据,否则就可能使用的前一个会话遗留的脏数据。
@RestController public class UserController { private static final ThreadLocal<UserInfo> USER_INFO = new ThreadLocal<>(); private static final Executor EXECUTOR = Executors.newFixedThreadPool(10); @Autowired private UserService userService; @RequestMapping("/user") public User login(UserInfo userInfo) { USER_INFO.set(userInfo); User user = userService.login(); EXECUTOR.execute(()->{ //注意线程池需要手动管理ThreadLocal USER_INFO.set(userInfo); userService.doSomething(); USER_INFO.remove(); }); return user; } } -
连接管理:当我们有一个线程池用来处理一些远程任务,每个任务都需要与远程主机建立连接,那么为了减少频繁建立连接带来的性能开销,我们可以使用ThreadLocal来保存这些连接,使之与线程的生命周期一致,这样就避免了频繁建立远程连接带来的开销。
@RestController public class UserController { private static final ThreadLocal<UserInfo> USER_INFO = new ThreadLocal<>(); private static final ThreadLocal<RemoteConnection> REMOTE_CONNECTION = new ThreadLocal<>(); private static final Executor EXECUTOR = Executors.newFixedThreadPool(10); @Autowired private UserService userService; @RequestMapping("/user") public User login(UserInfo userInfo) { USER_INFO.set(userInfo); User user = userService.login(); EXECUTOR.execute(()->{ //注意线程池需要手动管理ThreadLocal USER_INFO.set(userInfo); //仅需要第一个任务初始化connection RemoteConnection connection = REMOTE_CONNECTION.get(); if(connection == null) { connection = RemoteUtil.getConnection(); REMOTE_CONNECTION.set(connection); } userService.doSomething(connection); USER_INFO.remove(); }); return user; } }
三、ThreadLocal原理
3.1 ThreadLocal原理
事实上ThreadLocal本身并不存储数据,它只是数据的 管家 。在Thread内部有threadLocals(对应ThreadLocal)和inheritableThreadLocals(对应InheritableThreadLocal)两个Map(它们也是hash map,但有自己的实现),它们维护者ThreadLocal/InheritableThreadLocal的副本。下面以threadLocals为例,threadLocals类型是ThreadLocal.ThreadLocalMap,它定义在ThreadLocal类中,其key是ThreadLocal变量的弱引用,value是对应副本值。
对于同一个ThreadLocal在不同Thread中,threadLocals中的key是同一个对象,这确保了同一个threadLocal变量能检索到 同一个类型 的value,但value在不同线程之间是独立的。
对于不同ThreadLocal在同一个线程中来说,不同的Thread Local对应 不同类型 的值,也就是threadLocals中的多个entry。
ThreadLocalMap中的key继承了WeakReference,因为ThreadLocal对于用户而言就是一个普通变量,它的生命周期应当符合一般变量行为。如果这里是强引用,那么即便用户将其引用置为null(或者方法返回、对象回收等等),该ThreadLocal对象可能依然无法被回收,因为还有其他线程的threadLocalMap中的entry对其有强引用。
将Entry中的key设为弱引用即可解决ThreadLocal GC回收的问题,但对应value又会带来内存泄露,对于value而言依然有thread -> threadLocalMap -> entry -> value这样的引用链存在,且该value永远无法被访问,直到线程结束。为解决这一问题,该Map中新增了部分对"stale entry"的回收逻辑。
3.2 ThreadLocal源码
点击展开代码
//一个变量对应一个ThreadLocal,该变量在不同线程之间有不同的副本
public class ThreadLocal<T> {
//每个线程都有一个threadLocalMap变量,其key为一个ThreadLocal变量,value为我们设置的值
//这里的threadLocalHashCode就充当了key的hashCode,它就是threadLocalMap中用于计算hash地址的依据
//ThreadLocal中计算hashCode的方法很简单,就是从0开始不断累加HASH_INCREMENT
//为保证不同ThreadLocal对象之间HashCode不同,其不断累加的临时变量nextHashCode是static的
private final int threadLocalHashCode = nextHashCode();
//用于计算threadLocalHashCode的【静态】临时变量,它保存的是下一个ThreadLocal变量的Hash地址
//并每创建一个ThreadLocal对象更新一次
private static AtomicInteger nextHashCode = new AtomicInteger();
//不断累加的HashCode差值,这个值可以保证在2^N范围类hashcode均匀分布,保证检索效率。
//该值与黄金分割和斐波那契散列算法有关,参考:https://zhuanlan.zhihu.com/p/40515974
private static final int HASH_INCREMENT = 0x61c88647;
//计算nextHashCode并返回当前对象的hashCode
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
//当某个线程下对应的ThreadLocal变量未设置时(threadLocalMap对应的k-v不存在)
//就会调用该方法返回一个初始值,默认返回null,可以继承重写该方法
protected T initialValue() {
return null;
}
//返回一个具有指定初始化supplier的ThreadLocal对象
//它会使用supplier获取(在initialValue方法中调用)初始值
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}
//默认无参构造方法
public ThreadLocal() {
}
//返回该ThreadLocal变量对应于当前线程的副本值
public T get() {
//获取当前线程
Thread t = Thread.currentThread();
//获取当前线程的ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null) {
//获取当前threadLocal对应的entry
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
//当前threadLocal对应的value
T result = (T)e.value;
return result;
}
}
//当前threadLocal对应值不存在就初始化,并返回初始化的值
return setInitialValue();
}
//如果当前线程的ThreadLocalMap为null或者当前ThreadLocal值不存在,
//就初始化map(如果需要的话)和当前threadLocal对应的k-v,并返回初始值
//注意ThreadLocalMap虽然在Thread中,但Thread并不初始化,全部交由ThreadLocal管理
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
//map存在就设置初始值
if (map != null)
map.set(this, value);
//map不存在就创建map并初始化
else
createMap(t, value);
return value;
}
//设置当前线程对应ThreadLocal的值
//过程与setInitialValue一致
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
//移除当前线程对应于该ThreadLocal的副本(建议不使用threadLocal的时候手动释放)
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
//获取线程关联的threadLocalMap
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
//创建一个ThreadLocalMap并使用给定初始化参数(this->firstValue)初始化
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
//创建一个从父线程继承下来的ThreadLocalMap,该Map包含父线程ThreadLocalMap的全部值
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
return new ThreadLocalMap(parentMap);
}
//InheritableThreadLocal中使用
T childValue(T parentValue) {
throw new UnsupportedOperationException();
}
//用supplier作为初始化器的ThreadLocal实现类
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
//用户提供的初始化器
private final Supplier<? extends T> supplier;
SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}
//通过supplier获取初始值
@Override
protected T initialValue() {
return supplier.get();
}
}
//ThreadLocalMap的实现,仅用于保存ThreadLocal变量
static class ThreadLocalMap {
//Entry定义,注意key的弱引用
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
//初始容量
private static final int INITIAL_CAPACITY = 16;
//hash桶
private Entry[] table;
//存储的entity个数
private int size = 0;
//扩容阈值
private int threshold; // Default to 0
//设置扩容阈值为2/3
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
//使用线性探查法寻找下一个地址
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
//使用线性探查法寻找上一个地址(回收脏entry的时候使用)
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}
//根据给定的初始值,初始化ThreadLocalMap(至少有一个值才会初始化)
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}
//拷贝构造
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];
for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
//使用线性探查法解决冲突,和hashMap不一样
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}
//根据key获取entry
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
//注意这里key的比较是==,不是equals,因为不同线程存的threadLocal就是同一个对象的引用
if (e != null && e.get() == key)
return e;
else //线性探查寻找entry
return getEntryAfterMiss(key, i, e);
}
//线性探查寻找entry
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal<?> k = e.get();
//找到了
if (k == key)
return e;
//key==null表示该value是一个泄露值(永远无法被使用),需要进行回收
if (k == null)
expungeStaleEntry(i);
else//继续探查下一个位置
i = nextIndex(i, len);
e = tab[i];
}//没有找到返回null
return null;
}
//添加一个键值对
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
//key存在直接替换值
if (k == key) {
e.value = value;
return;
}
//如果当前选定位置是一个“stale entry”,则按照一定的算法插入entry
//并清理一定范围内的“stale entry”
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
//如果冲突了就再次返回调用nextIndex探查下一个地址
}
tab[i] = new Entry(key, value);
int sz = ++size;
//如果???
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
//移除一个元素(不会导致"stale entry")
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear(); //清除value
//清理脏entry
expungeStaleEntry(i);
return;
}
}
}
// 插入新entry并清理周围"stale entry"
// 新的entry一定在staleSlot位置
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
//回溯查找泄露entry的最小位置,直到遇见一个未占用的hash桶为止
//中途记录下泄露entry的最小位置,作为清理的起始点
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null;
i = prevIndex(i, len)) {
if (e.get() == null)
slotToExpunge = i;
}
//向后查找泄露entry的最大位置,直到遇见一个未占用的hash桶或找到一个可替换的entry为止
//并在合适的时机插入新entry
//注意:从i到下一个未占用的hash桶之间是必须要遍历的,这样可以检查我们要插入的key先前
//是否出现过,如果出现过就必须替换value,这样才能保证key的唯一性
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
//如果发现有可替换的entry
if (k == key) {
e.value = value; //替换value
//把当前entry换到staleSlot处,注意staleSlot是该key第一个可用位置
//交换既满足了key的唯一性,又尽可能保证key的检索效率
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
//如果一直没有找到脏entry,就以当前位置为起点清理
//注意,当前位置是和staleSlot交换过的,所以当前位置一定是一个脏entry
if (slotToExpunge == staleSlot)
slotToExpunge = i;
//清理脏entry
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
//如果前向搜索没有发现脏enyrt,并且当前节点是脏entry,就以当前位置为清理的起始点
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}// end for
//如果没有找到可替换的entry,将staleSlot位置释放,并重新填充一个新entry
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
//如果遍历过程中发现了其他脏entry,则清理之
//slotToExpunge一直是清理的起始点,起始点尽可能小,清理范围尽可能大
if (slotToExpunge != staleSlot)
//两次清理
//第一次:清理slotToExpunge到下一个null slot之间的脏entry
//第二次:从下一个null的下一个位置开始,最少扫描log2(len)次
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
//清理stale entry
//参数:staleSlot 当前stale entry的位置
//返回:下一个空hash桶(slot)的位置,这个区间内所有的脏entry都会被清理掉
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// 清理当前指定位置staleSlot上的stale entry
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
//向后遍历直到下一个空桶位置
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
//遇见stale entry则清理之
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
//如果遇见正常的entry就重新计算地址
//因为前面释放了stale entry,当前entry很大可能有更优的存储位置
//减少e的线性探查次数,提升访问效率
int h = k.threadLocalHashCode & (len - 1);
//e的hash地址不在当前位置,证明e一定是因为hash地址冲突而放到了这里
//现在释放了之前部分stale entry,e很可能有更优秀的位置
//如果h==i,则当前位置就是e的最佳位置,不做任何操作
if (h != i) {
//释放当前位置
tab[i] = null;
//从h向后重新给e找位置
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
//返回下一个空slot的位置
return i;
}
//从i的下一个位置开始清理
//n控制扫描次数,扫描log2(n)轮,每一轮根据情况扫描一段或一个,整体时间复杂度为nlog2(n)
//兼顾扫描效率和清理效果
//如果有stale entry被删除则返回true
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
//如果发现stale entry
if (e != null && e.get() == null) {
n = len;
removed = true;
//移除[i, next null slot]之间的脏entry,i为next null slot位置
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}
//rehash
private void rehash() {
expungeStaleEntries();
// size >= 5 / 12(len) ?
if (size >= threshold - threshold / 4)
resize();
}
//扩容,新容量为旧容量的两倍
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
//遍历旧table(将旧table中的内容移动到新table)
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
//如果还遇到stale entry,直接unlink
if (k == null) {
e.value = null; // Help the GC
}
//计算e在新表中的位置,并放入其中
else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
//更新threshold
setThreshold(newLen);
size = count;
table = newTab;
}
// 清理所有脏entry,并重新计算每一个有效元素的最佳位置
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
}