ThreadLocal源码阅读

一、ThreadLocal概述

        在Java开发中,多线程是一个永远绕不开的话题。Java服务器中通常使用一组线程处理一个会话或一个连接,这一组线程一般具有父子关系,它们往往需要客户端传递的用户信息来完成业务逻辑,这些信息仅针对本次连接或会话有效。当然我们可以以参数的形式传递这些数据,但当参数异常复杂的时候,层间传递就变得很枯燥,有没有一种方法可以简化这种场景?当然有,他就是ThreadLocal技术,ThreadLocal是一种线程局部变量共享机制,它独立于语言,但本次仅讨论Java语言范畴。

        使用ThreadLocal维护的变量在每个线程内部维护一个副本,对该变量的修改仅对该线程可见(指ThreadLocal.set(obj),对象成员修改除外),同样可以在线程内部的任意位置使用它。这样,既满足了线程之间的隔离要求,又减少了参数的传递工作,何乐而不为之!另外,我们上面提到,如果是一组线程(这一组线程具有父子关系)需要变量共享呢?没问题,Java给我们提供了InheritableThreadLocal,只要我们在父线程中设置了相关变量,子线程会自动继承这些变量的值,但本质上是子线程初始化自己的副本时使用了父线程的值,此后对各自副本的修改(指ThreadLocal.set(obj),对象成员修改除外)也仅在当前线程生效

        注意,网上很多博客写着ThreadLocal是用来解决多线程并发问题的,这种理解在我看来是错误的。多线程并发指的是多个线程对同一个临界区操作的互斥问题,或多个线程之间的同步问题,而ThreadLocal在每个线程都有一个副本,不存在互斥,也不存在同步,因此跟多线程并发问题无关。

二、ThreadLocal运用

上面简单介绍了ThreadLocal的性质,下面来看一下具体的运用场景:

  1. Session管理:正如我们开篇提到的用户信息传递问题,本质上就是一个Session管理,它在一次会话开始的时候创建ThreadLocal变量保存所有全局信息,会话结束的时候释放ThreadLocal。如果线程的生命周期与会话的生命周期一致,则可以不用手动释放ThreadLocal变量,如果使用了线程池就必须在提交任务时手动初始化ThreadLocal,结束任务时手动清理ThreadLocal保存的数据,否则就可能使用的前一个会话遗留的脏数据。

    @RestController
    public class UserController {
        private static final ThreadLocal<UserInfo> USER_INFO = new ThreadLocal<>();
        private static final Executor EXECUTOR = Executors.newFixedThreadPool(10);
        
        @Autowired
        private UserService userService;
        
        @RequestMapping("/user") 
        public User login(UserInfo userInfo) {
            USER_INFO.set(userInfo);
            User user = userService.login();
            EXECUTOR.execute(()->{
                //注意线程池需要手动管理ThreadLocal
                USER_INFO.set(userInfo);
                userService.doSomething();
                USER_INFO.remove();
            });
            return user;
        }
    }
    
  2. 连接管理:当我们有一个线程池用来处理一些远程任务,每个任务都需要与远程主机建立连接,那么为了减少频繁建立连接带来的性能开销,我们可以使用ThreadLocal来保存这些连接,使之与线程的生命周期一致,这样就避免了频繁建立远程连接带来的开销。

    @RestController
    public class UserController {
        private static final ThreadLocal<UserInfo> USER_INFO = new ThreadLocal<>();
        private static final ThreadLocal<RemoteConnection> REMOTE_CONNECTION = new ThreadLocal<>();
        private static final Executor EXECUTOR = Executors.newFixedThreadPool(10);
    
        @Autowired
        private UserService userService;
    
        @RequestMapping("/user")
        public User login(UserInfo userInfo) {
            USER_INFO.set(userInfo);
            User user = userService.login();
            EXECUTOR.execute(()->{
                //注意线程池需要手动管理ThreadLocal
                USER_INFO.set(userInfo);
                //仅需要第一个任务初始化connection
                RemoteConnection connection = REMOTE_CONNECTION.get();
                if(connection == null) {
                    connection = RemoteUtil.getConnection();
                    REMOTE_CONNECTION.set(connection);
                }
                userService.doSomething(connection);
                USER_INFO.remove();
            });
            return user;
        }
    }
    
  3. ThreadLocal在Spring事务管理中的应用

三、ThreadLocal原理

3.1 ThreadLocal原理

        事实上ThreadLocal本身并不存储数据,它只是数据的 管家 。在Thread内部有threadLocals(对应ThreadLocal)和inheritableThreadLocals(对应InheritableThreadLocal)两个Map(它们也是hash map,但有自己的实现),它们维护者ThreadLocal/InheritableThreadLocal的副本。下面以threadLocals为例,threadLocals类型是ThreadLocal.ThreadLocalMap,它定义在ThreadLocal类中,其key是ThreadLocal变量的弱引用,value是对应副本值。

        对于同一个ThreadLocal在不同Thread中,threadLocals中的key是同一个对象,这确保了同一个threadLocal变量能检索到 同一个类型 的value,但value在不同线程之间是独立的。

        对于不同ThreadLocal在同一个线程中来说,不同的Thread Local对应 不同类型 的值,也就是threadLocals中的多个entry。

ThreadLocal与多线程之间的关系

        ThreadLocalMap中的key继承了WeakReference,因为ThreadLocal对于用户而言就是一个普通变量,它的生命周期应当符合一般变量行为。如果这里是强引用,那么即便用户将其引用置为null(或者方法返回、对象回收等等),该ThreadLocal对象可能依然无法被回收,因为还有其他线程的threadLocalMap中的entry对其有强引用。

        将Entry中的key设为弱引用即可解决ThreadLocal GC回收的问题,但对应value又会带来内存泄露,对于value而言依然有thread -> threadLocalMap -> entry -> value这样的引用链存在,且该value永远无法被访问,直到线程结束。为解决这一问题,该Map中新增了部分对"stale entry"的回收逻辑。

ThreadLocal弱引用

3.2 ThreadLocal源码

点击展开代码
//一个变量对应一个ThreadLocal,该变量在不同线程之间有不同的副本
public class ThreadLocal<T> {
    //每个线程都有一个threadLocalMap变量,其key为一个ThreadLocal变量,value为我们设置的值
    //这里的threadLocalHashCode就充当了key的hashCode,它就是threadLocalMap中用于计算hash地址的依据
    //ThreadLocal中计算hashCode的方法很简单,就是从0开始不断累加HASH_INCREMENT
    //为保证不同ThreadLocal对象之间HashCode不同,其不断累加的临时变量nextHashCode是static的
    private final int threadLocalHashCode = nextHashCode();

    //用于计算threadLocalHashCode的【静态】临时变量,它保存的是下一个ThreadLocal变量的Hash地址
    //并每创建一个ThreadLocal对象更新一次
    private static AtomicInteger nextHashCode = new AtomicInteger();
    
    //不断累加的HashCode差值,这个值可以保证在2^N范围类hashcode均匀分布,保证检索效率。
    //该值与黄金分割和斐波那契散列算法有关,参考:https://zhuanlan.zhihu.com/p/40515974
    private static final int HASH_INCREMENT = 0x61c88647;
    
    //计算nextHashCode并返回当前对象的hashCode
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }
    
    //当某个线程下对应的ThreadLocal变量未设置时(threadLocalMap对应的k-v不存在)
    //就会调用该方法返回一个初始值,默认返回null,可以继承重写该方法
    protected T initialValue() {
        return null;
    }
    
    //返回一个具有指定初始化supplier的ThreadLocal对象
    //它会使用supplier获取(在initialValue方法中调用)初始值
    public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
        return new SuppliedThreadLocal<>(supplier);
    }
    
    //默认无参构造方法
    public ThreadLocal() {
    }
    
    //返回该ThreadLocal变量对应于当前线程的副本值
    public T get() {
        //获取当前线程
        Thread t = Thread.currentThread();
        //获取当前线程的ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //获取当前threadLocal对应的entry
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                //当前threadLocal对应的value
                T result = (T)e.value;
                return result;
            }
        }
      	//当前threadLocal对应值不存在就初始化,并返回初始化的值
        return setInitialValue();
    }
    
    //如果当前线程的ThreadLocalMap为null或者当前ThreadLocal值不存在,
    //就初始化map(如果需要的话)和当前threadLocal对应的k-v,并返回初始值
    //注意ThreadLocalMap虽然在Thread中,但Thread并不初始化,全部交由ThreadLocal管理
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        //map存在就设置初始值
        if (map != null)
            map.set(this, value);
        //map不存在就创建map并初始化
        else
            createMap(t, value);
        return value;
    }
    
    //设置当前线程对应ThreadLocal的值
    //过程与setInitialValue一致
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
    
    //移除当前线程对应于该ThreadLocal的副本(建议不使用threadLocal的时候手动释放)
    public void remove() {
        ThreadLocalMap m = getMap(Thread.currentThread());
        if (m != null)
            m.remove(this);
    }
    
    //获取线程关联的threadLocalMap
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    
    //创建一个ThreadLocalMap并使用给定初始化参数(this->firstValue)初始化
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
    
    //创建一个从父线程继承下来的ThreadLocalMap,该Map包含父线程ThreadLocalMap的全部值
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }
    
    //InheritableThreadLocal中使用
    T childValue(T parentValue) {
        throw new UnsupportedOperationException();
    }
    
    //用supplier作为初始化器的ThreadLocal实现类
    static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
    
        //用户提供的初始化器
        private final Supplier<? extends T> supplier;
    
        SuppliedThreadLocal(Supplier<? extends T> supplier) {
            this.supplier = Objects.requireNonNull(supplier);
        }
    
        //通过supplier获取初始值
        @Override
        protected T initialValue() {
            return supplier.get();
        }
    }
    
    //ThreadLocalMap的实现,仅用于保存ThreadLocal变量
    static class ThreadLocalMap {
    
        //Entry定义,注意key的弱引用
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;
    
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
    
        //初始容量
        private static final int INITIAL_CAPACITY = 16;
    
        //hash桶
        private Entry[] table;
    
        //存储的entity个数
        private int size = 0;
    
        //扩容阈值
        private int threshold; // Default to 0
    
        //设置扩容阈值为2/3
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }
    
        //使用线性探查法寻找下一个地址
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }
    
        //使用线性探查法寻找上一个地址(回收脏entry的时候使用)
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }
    
        //根据给定的初始值,初始化ThreadLocalMap(至少有一个值才会初始化)
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }
    
        //拷贝构造
        private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];
    
            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            //使用线性探查法解决冲突,和hashMap不一样
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }
    
        //根据key获取entry
        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            //注意这里key的比较是==,不是equals,因为不同线程存的threadLocal就是同一个对象的引用
            if (e != null && e.get() == key)
                return e;
            else //线性探查寻找entry
                return getEntryAfterMiss(key, i, e);
        }
    
        //线性探查寻找entry
        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;
    
            while (e != null) {
                ThreadLocal<?> k = e.get();
                //找到了
                if (k == key)
                    return e;
                //key==null表示该value是一个泄露值(永远无法被使用),需要进行回收
                if (k == null)
                    expungeStaleEntry(i);
                else//继续探查下一个位置
                    i = nextIndex(i, len);
                e = tab[i];
            }//没有找到返回null
            return null;
        }
    
        //添加一个键值对
        private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
    
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
    
                //key存在直接替换值
                if (k == key) {
                    e.value = value;
                    return;
                }
    
                //如果当前选定位置是一个“stale entry”,则按照一定的算法插入entry
                //并清理一定范围内的“stale entry”
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
                //如果冲突了就再次返回调用nextIndex探查下一个地址
            }
    
            tab[i] = new Entry(key, value);
            int sz = ++size;
            //如果???
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }
    
        //移除一个元素(不会导致"stale entry")
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear(); //清除value
                    //清理脏entry
                    expungeStaleEntry(i);
                    return;
                }
            }
        }
    
        // 插入新entry并清理周围"stale entry"
        // 新的entry一定在staleSlot位置
        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;
    
            //回溯查找泄露entry的最小位置,直到遇见一个未占用的hash桶为止
            //中途记录下泄露entry的最小位置,作为清理的起始点
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; 
                 i = prevIndex(i, len)) {
                if (e.get() == null)
                    slotToExpunge = i;
            }
    
            //向后查找泄露entry的最大位置,直到遇见一个未占用的hash桶或找到一个可替换的entry为止
            //并在合适的时机插入新entry
            //注意:从i到下一个未占用的hash桶之间是必须要遍历的,这样可以检查我们要插入的key先前
            //是否出现过,如果出现过就必须替换value,这样才能保证key的唯一性
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
    
    			//如果发现有可替换的entry
                if (k == key) {
                    e.value = value; //替换value
                    //把当前entry换到staleSlot处,注意staleSlot是该key第一个可用位置
                    //交换既满足了key的唯一性,又尽可能保证key的检索效率
                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;
    
                    //如果一直没有找到脏entry,就以当前位置为起点清理
                    //注意,当前位置是和staleSlot交换过的,所以当前位置一定是一个脏entry
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
    				//清理脏entry
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }
    
                //如果前向搜索没有发现脏enyrt,并且当前节点是脏entry,就以当前位置为清理的起始点
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }// end for
    
            //如果没有找到可替换的entry,将staleSlot位置释放,并重新填充一个新entry
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);
    
            //如果遍历过程中发现了其他脏entry,则清理之
            //slotToExpunge一直是清理的起始点,起始点尽可能小,清理范围尽可能大
            if (slotToExpunge != staleSlot)
                //两次清理
                //第一次:清理slotToExpunge到下一个null slot之间的脏entry
                //第二次:从下一个null的下一个位置开始,最少扫描log2(len)次
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }
    
        //清理stale entry
        //参数:staleSlot 当前stale entry的位置
        //返回:下一个空hash桶(slot)的位置,这个区间内所有的脏entry都会被清理掉
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
    
            // 清理当前指定位置staleSlot上的stale entry
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;
    
            //向后遍历直到下一个空桶位置
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                //遇见stale entry则清理之
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    //如果遇见正常的entry就重新计算地址
                    //因为前面释放了stale entry,当前entry很大可能有更优的存储位置
                    //减少e的线性探查次数,提升访问效率
                    int h = k.threadLocalHashCode & (len - 1);
                    //e的hash地址不在当前位置,证明e一定是因为hash地址冲突而放到了这里
                    //现在释放了之前部分stale entry,e很可能有更优秀的位置
                    //如果h==i,则当前位置就是e的最佳位置,不做任何操作
                    if (h != i) {
                        //释放当前位置
                        tab[i] = null;
                        //从h向后重新给e找位置
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            //返回下一个空slot的位置
            return i;
        }
    
        //从i的下一个位置开始清理
        //n控制扫描次数,扫描log2(n)轮,每一轮根据情况扫描一段或一个,整体时间复杂度为nlog2(n)
        //兼顾扫描效率和清理效果
        //如果有stale entry被删除则返回true
        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                //如果发现stale entry
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    //移除[i, next null slot]之间的脏entry,i为next null slot位置
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

  		//rehash
        private void rehash() {
            expungeStaleEntries(); 
            // size >= 5 / 12(len) ?
            if (size >= threshold - threshold / 4)
                resize();
        }

        //扩容,新容量为旧容量的两倍
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;
    
            //遍历旧table(将旧table中的内容移动到新table)
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    //如果还遇到stale entry,直接unlink
                    if (k == null) {
                        e.value = null; // Help the GC
                    } 
                    //计算e在新表中的位置,并放入其中
                    else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }
    
            //更新threshold
            setThreshold(newLen);
            size = count;
            table = newTab;
        }
    
        // 清理所有脏entry,并重新计算每一个有效元素的最佳位置
        private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }
    }