转载

浅析 ThreadLocal 原理

文章目录(站在前人的肩膀上总结了这篇博文,因水平有限,如有错误欢迎直接对喷):

  1. ThreadLocal 解决什么问题
  2. ThreadLocal 用法及其使用场景
  3. ThreadLocal 实现原理
  4. ThreadLocal 不支持继承性
  5. InheritableThreadLocal 类的作用

1 ThreadLocal 解决什么问题

关于 ThreadLocal 原理的博文已经不计其数了,百度查到的文章大多对 ThreadLocal 的使用场景和解决的问题都掺杂些许误区,从 google 搜索的博文大多都会指出误区进行纠正,网上最常见的误区有如下两点:

  1. ThreadLocal 为解决多线程程序的并发问题提供了一种新的思路
  2. ThreadLocal 的目的是为了解决多线程访问资源时的共享问题

那么为什么会指出这两点理解是有误的?我们不妨直接看下 ThreadLocal 的源码描述:

This class provides thread-local variables. These variables differ from their normal counterparts in that each thread that accesses one (via its get or set method) has its own, independently initialized copy of the variable. ThreadLocal instances are typically private static fields in classes that wish to associate state with a thread (e.g., a user ID or Transaction ID).
Each thread holds an implicit reference to its copy of a thread-local variable as long as the thread is alive and the ThreadLocal instance is accessible; after a thread goes away, all of its copies of thread-local instances are subject to garbage collection (unless other references to these copies exist).

实际上 ThreadLocal 并不是 J.U.C 包提供的,它来源于 java.lang 包。源码上指出,ThreadLocal 提供了线程本地变量,每个使用该变量的线程都有自己独立初始化的变量副本。不会于其他线程发生冲突,实现了线程间数据的隔离。只要线程存活并且 ThreadLocal 实例可以访问,每个线程都会保存对其线程局部变量副本的隐式引用。线程消失后,线程本地实例的所有副本都会被垃圾收集(除非存在对这些副本的其他引用),因此 ThreadLocal 实例通常被设为 private static 字段。

因此我们可以发现,使用 ThreadLocal 并不会存在 数据共享问题 ,也就不存在同步问题,所以将 ThreadLocal 与多线程并发问题进行比较似乎有些有些牛头不对马嘴,也难怪有很多博文对此进行了批判。

2 ThreadLocal 用法及其使用场景

2.1 ThreadLocal 用法

这里用 《并发编程之美》 中的一个简单例子演示以下 ThreadLocal 的使用,下面代码开启了两个线程,每个线程内部都设置了本地变量,然后调用 localVariable 打印当前本地变量的值。

public class ThreadLocalTest {
    // 1 创建 ThreadLocal 变量
    private static ThreadLocal<String> localVariable = new ThreadLocal<>();     
    // 2 创建线程 one
    public static void main(String[] args) {
        Thread threadOne = new Thread(() -> {
            // 2.1 设置线程 one 中本地变量 localVariable 的值
            localVariable.set("threadOne local variable");
            // 2.2 打印当前线程本地内存中 localVariable 变量的值           
            System.out.println("threadOne: " + localVariable.get());
            // 2.3 打印本地变量             
            System.out.println("threadOne remove after" + ": " + localVariable.get());
        });
        // 3 创建线程 two
        Thread threadTwo = new Thread(() -> {
            // 3.1 设置线程 two 中本地变量 localVariable 的值
            localVariable.set("threadTwo local variable");
            // 3.2 打印当前线程本地内存中 localVariable 变量的值
            System.out.println("threadTwo: " + localVariable.get());
            // 3.3 打印本地变量 
            System.out.println("threadTwo remove after" + ": " + localVariable.get());
        });
        threadOne.start();
        threadTwo.start();
    }
}
复制代码

运行结果如下:

threadOne: threadOne local variable
threadOne remove after: threadOne local variable
threadTwo: threadTwo local variable
threadTwo remove after: threadTwo local variable
复制代码

可以发现,调用线程通过 set 方法设置了 localVariable 的值(实际上设置的是调用线程本地内存中的一个副本,原理在下一节给出),而这个值其他线程是访问不了的。

2.2 ThreadLocal 使用场景

个人在实际工作和生活中遇见使用 ThreadLocal 的场景有如(如果纯粹写 CURD,用到 ThreadLocal 的场景应该几乎没有):

  1. 使用 ThreadLocal 包裹一些线程不安全的工具类,如 Random,SimpleDateFormat 等。
  2. 使用 ThreadLocal 与 AOP 进行结合,例如在前置通知设置当前线程的一些重要信息,在后置通知获取这些重要信息。
  3. 数据库连接池中的 Connection 由 ThreadLocal 来管理,保证线程中多个 dao 操作,用的都是同一个 Connection 使事务得以保证。
  4. Spring 能实现在多线程环境下,将各个线程的 request 进行隔离,且准确无误的进行注入,奥秘就是 ThreadLocal。

3 ThreadLocal 实现原理

在讲原理之前,先贴出在掘金上看到表述的非常清晰的一张结构图(图中实线是强引用,虚线是弱引用):

浅析 ThreadLocal 原理

根据图片可以看到,每个 Thread 对象内部都会维护一个 ThreadLocalMap 类型的成员变量(threadLocals 字段),存储在 ThreadLocalMap 内部是一个以 Entry 为元素的 table 数据,Entry 是一个 key-value 结构,key 为 ThreadLocal,value 为存储的值。由此可以知道每个线程是通过借助哈希表来存储线程独立的值。

在深入源码分析之前,我简单介绍下自己对 ThreadLocal 的理解:ThreadLocal 实际只是一个工具类,当线程第一次调用 ThreadLocal 的 set/get 方法时,会为线程的 threadLocals 字段初始化一个 ThreadLocalMap 对象。当线程调用 set 方法时,实际是将 value 放入调用线程的 threadLocals 中。当线程调用 get 方法时,实际是从调用线程的 threadLocals 中获取。如果调用线程一直不终止,那么这个本地变量会一直存放在调用线程的 threadLocals 中,所以当不再需要本地变量时最好通过调用 remove 方法从调用线程的 threadLocals 中删除该本地变量。

上面是 ThreadLocal 的基本原理,要想更深入了解 ThreadLocal 内部实现,还是要回归源码。接下来会分成 ThreadLocal 和 ThreadLocalMap 两个部分来介绍,这里我选择优先介绍 ThreadLocalMap,它是 ThreadLocal 中定制的一个 map,ThreadLocal 的方法其实只是一些简单的工具壳,理解了 ThreadLocalMap 的实现,工具壳实际上就不难理解了。

3.1 ThreadLocalMap 源码分析

下面对 ThreadLocalMap 的字段和方法都进行讲解,一些基本字段上的描述我就直接贴在代码中了:

static class ThreadLocalMap {
    static class Entry extends WeakReference<ThreadLocal<?>> {
        Object value;
        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }
    /** Entry 数组的初始化大小 */
    private static final int INITIAL_CAPACITY = 16;    
    /** Entry 数组 */
    private Entry[] table;
    /** 记录 map 中实际存在的 entry 个数 */
    private int size = 0;
    /** 扩容的阈值,当 size 达到阈值时需要 resize 整个 map */
    private int threshold; 

    /** 根据长度计算扩容的阈值 = len * 2 / 3 */
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }

    /** 获取下一个索引,超出长度则返回 0。可以看出 entry 数组实际上是一个环形结构 */
    private static int nextIndex(int i, int len) {
        return ((i + 1 < len) ? i + 1 : 0);
    }
    /** 返回上一个索引,如果 -1 为负数,返回长度 -1 的索引。可以看出 entry 数组实际上是一个环形结构 */
    private static int prevIndex(int i, int len) {
        return ((i - 1 >= 0) ? i - 1 : len - 1);
    }

    /** 构造函数1: 同时初始化第一个 Entry 值 */
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        // 初始容量 16
        table = new Entry[INITIAL_CAPACITY];
        // 通过 threadLocal 的 hashcode & (table.length - 1) 的计算索引,确定键值对的位置
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        // 创建一个新节点保存在 table 当中
        table[i] = new Entry(firstKey, firstValue);
        // 初始化 size
        size = 1;
        // 初始化扩容阀值
        setThreshold(INITIAL_CAPACITY);
    }

    /** ThreadLocal 本身是线程隔离的,按道理是不会出现数据共享和传递的行为的,这是 InheritableThreadLocal 提供了了一种父子间数据共享的机制。具体实现原理在第 5 节进行介绍 */
    private ThreadLocalMap(ThreadLocalMap parentMap) {
        Entry[] parentTable = parentMap.table;
        int len = parentTable.length;
        setThreshold(len);
        table = new Entry[len];

        for (int j = 0; j < len; j++) {
            Entry e = parentTable[j];
            if (e != null) {
                @SuppressWarnings("unchecked")
                ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                if (key != null) {
                    Object value = key.childValue(e.value);
                    Entry c = new Entry(key, value);
                    int h = key.threadLocalHashCode & (len - 1);
                    while (table[h] != null)
                        h = nextIndex(h, len);
                    table[h] = c;
                    size++;
                }
            }
        }
    }
    // 裁剪掉留到后面继续讲 ...
复制代码

上面列举了 ThreadLocalMap 中定义的一些字段和构造函数,这里最受关心的应该是 Entry 结构问题:

  1. Entry 为何声明为 WeakReference(弱引用)?
  2. ThreadLocalMap 的 Entry 并没有使用链表结构,那是如何解决散列冲突的?

3.1.1 Entry 为何声明为 WeakReference(弱引用)?

试想一下,如果不声明为 WeakReference,则采用强引用的方式。在线程运行过程中,如果因为不再使用某个 ThreadLocal 而将其置为空,但由于该 ThreadLocal 在线程的 threadLocals 中还有引用,会导致其无法被 GC 回收(除非线程运行结束),进而可能产生内存泄漏。而当把 Entry 声明为 WeakReference 后,当某个 ThreadLocal 被置为空时,线程中 threadLocalMap 中的 key 就不算强引用了,该 ThreadLocal 就可以被 GC 回收了。同时在 threadLocalMap 内部的很多方法中,也会逐渐把对于的 stale entry 清理掉,避免内存泄漏。

3.1.2 ThreadLocalMap 的 Entry 并没有使用链表结构,那是如何解决散列冲突的?

实际上解决 hash 冲突通常惯用开放地址法和链地址法,我们熟知的 HashMap 便是使用链表法来处理 hash 冲突,而 ThreadLocalMap 则是使用开放地址法来处理 hash 冲突的,这两种解决方法的优缺点如下:

  • 开放地址法:容易产生堆积问题;不适于大规模的数据存储;散列函数的设计对冲突会有很大的影响;插入时可能会出现多次冲突的现象,删除的元素是多个冲突元素中的一个,需要对后面的元素作处理,实现较复杂;结点规模很大时会浪费很多空间;
  • 链地址法:处理冲突简单,且无堆积现象,平均查找长度短;链表中的结点是动态申请的,适合构造表不能确定长度的情况;相对而言,拉链法的指针域可以忽略不计,因此较开放地址法更加节省空间。插入结点应该在链首,删除结点比较方便,只需调整指针而不需要对其他冲突元素作调整;

开放地址法不会创建链表,当关键字散列到的数组单元已经被另外一个关键字占用的时候,就会尝试在数组中寻找其他的单元,直到找到一个空的单元。探测数组空单元的方式有很多,ThreadLocalMap 采用了最简单的线性探测法。线性探测法就是从冲突的数组单元开始,依次往后搜索空单元,如果到数组尾部,再从头开始搜索(环形查找),其公式是 fi(key)=(f(key)+di) MOD m (di=1,2,3,⋯⋯,m−1)fi(key)=(f(key)+di) MOD m (di=1,2,3,⋯⋯,m−1)

为了解决散列表的冲突 ThreadLocal 引入了神奇的 hash code: 0x61c88647 ,具体可以参考 《ThreadLocal 和神奇的数字 0x61c88647》 这篇文章,但在实际应用中,无论如何构造哈希函数,冲突都是无法避免的,因此下面会介绍 ThreadLocalMap 的其他方法,来看看 ThreadLocalMap 具体是如何解决这些问题的。

3.1.3 getEntry 和 getEntryAfterMiss 方法

通过方法名就可以知道这是从 ThreadLocalMap 中获取 Entry 节点的方法,我们来看下实现细节:

private Entry getEntry(ThreadLocal<?> key) {
    // 1 通过 hashcode % (table.length - 1) 确定下标位置
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    // 2 如果 key 相同,则将 entry 则直接返回
    if (e != null && e.get() == key)
        return e;
    else
        // 3 如果找不到的话需要接着从 i 位置开始向后遍历,基于线性探测法,是有可能在 i 之后找到对应的 entry 的
        return getEntryAfterMiss(key, i, e);
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
    // 1 如果 entry 不为空,说明有 hash 碰撞的情况,需循环向后遍历
    while (e != null) {
        // 2 获取节点对应的 key
        ThreadLocal<?> k = e.get();
        // 3 如果 key 相等则返回 entry
        if (k == key)
            return e;
        // 4 如果 key 为 null,触发一次连续段清理
        if (k == null)
            expungeStaleEntry(i);
        // 5 获取下一个下标位置接着进行判断
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}
复制代码

因为 ThreadLocalMap 使用了开放地址法,即有冲突后,会把要插入的元素放到插入位置后面的 null 的地方。因此第一次计算出来的下标位置不一定存在与我们想要的 key 恰好相等的 Entry,所以采用了线性探测法向后查找是否被插入到了后面。可以关注到,当进入 getEntryAfterMiss 进行线性探测时,如果获取的 key 为 null 则会触发 expungeStaleEntry 方法进行一次连续清理(这个方法在 ThreadLocalMap 中被大量应用)。那么为什么需要这样的清理呢?因为对于线程来说,隔离的本地变量使用的是 WeakReference,那么便有可能在 GC 的时候就被回收了,如果很多 Entry 节点已经被回收了,但 table 数组中还留着位置,这是不清理就很浪费资源。在清理节点的同时,可以将后续非空的 Entry 节点重写计算下标进行排放,这样在 get 的时候就能很快速定位资源,加快效率。

3.1.4 expungeStaleEntry 方法

下面我们深入 expungeStaleEntry 源码看看是怎么清理的:

/** 连续段清理 */
private int expungeStaleEntry(int staleSlot) {
    // 1 新开一个引用指向 table
    Entry[] tab = table;
    int len = tab.length;

    // 2 先将传过来已经被回收的下标置为 null,将 table 的 size - 1
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    Entry e;
    int i;
    // 3 遍历删除指定节点的所有后续节点中, ThreadLocal 被回收的节点
    遍历删除指定节点所有后续节点当中,ThreadLocal 被回收的节点
    for (i = nextIndex(staleSlot, len);
            (e = tab[i]) != null;
            i = nextIndex(i, len)) {
        // 4 获取 entry 当中的 key
        ThreadLocal<?> k = e.get();
        // 5 如果 entry 当中的 key 为空,则将 value 以及数组下标所在的位置设置为空,size - 1
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // 6 如果不为空,则重新计算 key 的下标
            int h = k.threadLocalHashCode & (len - 1);
            // 7 如果是当前位置则直接遍历下一个,如果不是当前位置,则重新从 i 开始找下一个为 null 的坐标进行赋值
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}
复制代码

显然这个方法实际上就是从 staleSlot 开始做一个连续段的清理和 rehash 操作。

3.1.5 set 方法

介绍完从 ThreadLocalMap 中获取 Entry 节点的方法,接下来再看下 set 方法,自然我们要将我们的变量保存进 ThreadLocalMap 中,我们还是从源码进行分析:

private void set(ThreadLocal<?> key, Object value) {
    // 1 新开一个引用指向 table
    Entry[] tab = table;
    // 2 获取table的长度
    int len = tab.length;
    // 3 通过 hashcode % (table.length - 1) 确定下标位置
    int i = key.threadLocalHashCode & (len-1);

    // 4 从该下标开始循环遍历
    for (Entry e = tab[i];
            e != null;
            e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        // 4.1 如果遇到相同的 key,则直接替换 value
        if (k == key) {
            e.value = value;
            return;
        }
        // 4.2 如果 key 为 null,则说明当前 key 作为弱引用被 GC 了,此时旧数据需要被清理
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 5 遍历到了数组为 null 的位置,进行赋值
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 6 调用 cleanSomeSlots 尝试性发现并清理失效 entry,如果没有发现且当前容量超过阈值,则调用 rehash
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}
复制代码

ThreadLocalMap 的 set 方法实际还是很复杂的,通过这个方法,我们可以看到该哈希表是如何通过线性探测法来解决冲突的。先来看下 replaceStaleEntry 方法:

3.1.6 replaceStaleEntry 方法

replaceStaleEntry 方法是当我们线性探测时,如果碰到 key 为 null 的元素时调用的。这个方法的功能可以理解为,我们在 staleSlot 位置发现 key 为 null 的元素,将新值覆盖到 staleSlot 位置上并清理 staleSlot 附近(即 staleSlot 位置前后连续的非 null 过期元素) key 为 null 的元素:

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // 1 记录最前的失效节点下标
    int slotToExpunge = staleSlot;
    // 2 向前遍历失效的 entry 节点,如果找到,就把节点下标赋给 slotToExpunge
    for (int i = prevIndex(staleSlot, len);
            (e = tab[i]) != null;
            i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;
    // 3 从 staleSlot 向后遍历
    for (int i = nextIndex(staleSlot, len);
            (e = tab[i]) != null;
            i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // 4 因为我们是在 set 方法一碰到脏数据就调用了次方法,但实际可能 key 在后置位置是存在的,则可能进入这段代码
        if (k == key) {
            e.value = value;
             // 4.1 这个时候会将找到的 entry 与失效节点进行交换,以维护哈希表顺序
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;


            // 4.2 如果之前没有找到过失效节点,则只有 i 位置是失效节点
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;

            // 4.3 先进行连续段清理,再调用 cleanSomeSlots 进行启发性清理,它的时间复杂度 O(log2n)
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // 5 如果向前(或向后)没有找到 过 失效节点,且当前下标是失效节点,则替换当前节点为失效节点
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // 6 如果当前 table 没有找到该 key,则在 staleSlot 这个位置重新创建一个 Entry
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 7 如果在之前的循环探测中,发现过无效的 entry 节点,即 slotToExpunge 被重新赋值过,就会触发连续段清理和启发式清理
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
复制代码

3.1.7 cleanSomeSlots 方法

前面已经介绍过 expungeStaleEntry 方法,它会从 staleSlot 开始做一个连续段的清理和 rehash 操作,而 cleanSomeSlots 方法则是进行启发式清理,它的执行复杂度log2(n),该方法只是尝试性地寻找一些失效 entry 若,有发现无效 entry 返回 true。

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        // 1 如果该节点已经失效
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            // 2 调用该方法进行连续性回收和 rehash
            i = expungeStaleEntry(i);
        }
    // n >>>= 1 无符号右移 1 位,即移动次数以 n 的二进制最高位的 1 的位置为基准,所以时间复杂度 log2(n)
    } while ( (n >>>= 1) != 0);
    return removed;
}
复制代码

在新增或替换元素成功后,为了尽可能少地在 get/set 时发现有旧元素的情况,会在清理失效 entry 后多次调用 cleanSomeSlots 尝试性地发现并清理一些旧元素,为了执行效率,“cleanSome” 是 “no clean” 不清理和 “clean all” 全量清理之间一的种平衡。

3.1.8 rehash / expungeStaleEntries / resize 方法

在成功 set 值之后,通过阀值判断,如果表空间不足就会调用 rehash 方法,该方法首先全量遍历清除失效节点,然后再清理后判断容量是否足够,如果不够则进行 2 倍扩容并重新散列。其中 expungeStaleEntries 则是全量清理旧元素,resize 则是二倍扩容。

/** rehash全量地遍历清理旧元素,然后判断容量若大于阈值的3/4,则扩容并从新散列 */
private void rehash() {
    // 1 全量遍历清理旧元素
    expungeStaleEntries();

    // 2 适当的扩容,以避免hash散列到数组时过多的位置冲突
    if (size >= threshold - threshold / 4)
        // 2.1 2倍扩容并重新散列
        resize();
}

/** 二倍扩容 */
private void resize() {
    // 1 获取旧 table 的长度,并且创建一个长度为旧长度 2 倍的 Entry 数组
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    // 2 记录插入的有效 Entry 节点数
    int count = 0;

    // 3 从下标 0 开始,逐个向后遍历插入到新的 table 当中
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            // 4 如遇到 key 已经为 null,则 value 设置 null,方便 GC 回收
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                // 5 通过 hashcode & len - 1 计算下标,如果该位置已经有 Entry 数组,则通过线性探测向后探测插入
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }
    // 6 重新设置扩容的阈值
    setThreshold(newLen);
    // 7 更新size
    size = count;
    // 8 指向新的 Entry 数组
    table = newTab;
}

/** 全量遍历清理旧元素 */
private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}
}
复制代码

3.1.9 remove 方法

​既然是 Map 形式进行存储,当然离不开 remove 方法:

/** 将 ThreadLocal 对象对应的 Entry 节点从 table 当中删除 */
private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
            e != null;
            e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            // 将引用设置 null,方便 GC
            e.clear();
            // 从该位置开始进行一次连续段清理
            expungeStaleEntry(i);
            return;
        }
    }
}
复制代码

可以看到,remove 节点的时候,也会使用线性探测的方式,当找到对应 key 的时候,就会调用 clear 将引用指向 null,并且会触发一次连续段清理。

3.2 ThreadLocal 源码分析

如果能理解 ThreadLocalMap 的结构,那么理解 ThreadLocal 就不是什么难题了,在我看来,ThreadLocal 只是封装了对 ThreadLocalMap 操作的工具方法,下面再逐个介绍 ThreadLocal 具体的几个方法实现。

3.2.1 get 方法

可以看到 get 方法的执行流程是:先获取当前 Thread 对象,再通过 getMap 获取当前线程的 threadLocals 变量,如果 threadLocals 不为空,则以当前 threadLocal 为 key,获取 entry 对象,然后从 entry 中取出 value。如果 threadLocals 为空,则调用 setInitialValue进行初始化。

public T get() {
    // 1 获取当前线程
    Thread t = Thread.currentThread();
    // 2 获取当前线程的 threadLocals 变量
    ThreadLocalMap map = getMap(t);
    // 3 如果 threadLocals 不为空,则返回对应本地变量的值
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // 4 如果 threadLocals 为空则初始化当前线程的 threadLocals 成员变量
    return setInitialValue();
}

// 可以看到 getMap 方法返回的就是当前 Thread 对象的 threadLocals 变量。
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

// setInitialValue 方法是在 threadLocals 为空的时候调用的
// 它先会调用 initialValue 方法生成一个 null 值
// 然后会校验以下 threadLocals 是否为空,如果不为空,直接设置 value 值,
// 确实不存在才调用 createMap 方法创建当前线程的 threadLocals 变量。
private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}
// 这里调用了 ThreadLocalMap 的构造方法,为 threadLocals 字段初始化 map 对象
void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

复制代码

3.2.1 set 方法

可以看到 set 方法中会先获取当前调用线程,再通过 getMap 方法获取调用线程的 threadLocals 变量。如果 threadLocals 不为空,则把值设置到 threadLocals 中。如果 threadLocals 为空,则创建当前线程的 threadLocals 变量。

public void set(T value) {
    // 1 获取当前线程
    Thread t = Thread.currentThread();
    // 2 将当前线程作为 key,去查找对应的线程变量,找到则设置,如果是第一次调用就创建当前线程对应的 hashMap
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

// 相关源码:
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}
void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}
复制代码

4 ThreadLocal 不支持继承性

先来看一段代码:

public class TestThreadLocal {

    // 1 创建线程变量
    public static ThreadLocal<String> threadLocal1 = new ThreadLocal<>();

    public static void main(String[] args) {
        // 2 设置线程变量
        threadLocal1.set("hello world");
        // 3 启动子线程
        Thread thread = new Thread(new Runnable() {
            @Override
            public void run() {
                // 4 子线程输出线程变量的值
                System.out.println("thread: " + threadLocal1.get());
            }
        });
        thread.start();
        // 5 主线程输出线程变量的值
        System.out.println("main: " + threadLocal1.get());
    }
}
复制代码

运行结果如下:

main: hello world
thread: null
复制代码

可以看到同一个 ThreadLocal 变量在父线程被设置值后,在子线程中是获取不到的,这属于正常现象,如有需要,可以通过后面介绍的 InheritableThreadLocal 来解决。

5 InheritableThreadLocal 类

5.1 InheritableThreadLocal 作用和原理

为了解决 threadLocal 不支持继承性的问题,即需要让子线程能访问到父线程中的值,就需要使用到 InheritableThreadLocal,下面是它的源码:

package java.lang;
import java.lang.ref.*;

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

复制代码

可以看到,InheritableThreadLocal 继承了 ThreadLocal,并重写了 3 个方法,craeteMap 方法创建的是当前线程的 inheritableThreadLocals 变量的实例,而不是 threadLoacals。调用 get 方法获取当前线程内部的 map 变量时,获取的是 inheritableThreadLocals 而不是 threadLoacals。所以要弄清 InheritableThreadLocal 如何让子线程可以访问父线程的本地变量,需要看 childValue 方法,在这之前我们先从 Thread 的构造说起:

private void init(ThreadGroup g, Runnable target, String name,
                    long stackSize, AccessControlContext acc,
                    boolean inheritThreadLocals) {
    if (name == null) {
        throw new NullPointerException("name cannot be null");
    }
    // 1 指定线程名称
    this.name = name;

    // 2 获取当前线程
    Thread parent = currentThread();
    // 3 获取系统安全服务
    SecurityManager security = System.getSecurityManager();
    // 4 如果没有线程组则设置线程组,如果为 null 则通过安全服务获取线程组,如果获取后的线程组仍为 null,则获取当前线程的线程组
    if (g == null) {
        if (security != null) {
            g = security.getThreadGroup();
        }
        if (g == null) {
            g = parent.getThreadGroup();
        }
    }
    // 5 检查线程组
    g.checkAccess();
    // 6 权限检查
    if (security != null) {
        if (isCCLOverridden(getClass())) {
            security.checkPermission(SUBCLASS_IMPLEMENTATION_PERMISSION);
        }
    }
    // 7 往线程组添加线程
    g.addUnstarted();

    this.group = g;
    this.daemon = parent.isDaemon();
    this.priority = parent.getPriority();
    if (security == null || isCCLOverridden(parent.getClass()))
        this.contextClassLoader = parent.getContextClassLoader();
    else
        this.contextClassLoader = parent.contextClassLoader;
    this.inheritedAccessControlContext =
            acc != null ? acc : AccessController.getContext();
    this.target = target;
    setPriority(priority);
    // 8 如果父线程的 inheritableThreadLocals 不为空
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
    // 9 设置子线程中 inheritableThreadLocals 变量
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    this.stackSize = stackSize;
    tid = nextThreadID();
}
复制代码

在线程创建时会调用 init 方法,该方法会获取当前线程(指的时是父线程),然后判断父线程的 inheritableThreadLocals 变量是否为 null,如果为 null 就调用 createInheritedMap 创建一个新的 ThreadLocalMap 变量然后赋值子线程的 inheritableThreadLocals 变量:

static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
    return new ThreadLocalMap(parentMap);
}

private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    setThreshold(len);
    table = new Entry[len];

    for (int j = 0; j < len; j++) {
        Entry e = parentTable[j];
        if (e != null) {
            @SuppressWarnings("unchecked")
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
            if (key != null) {
                Object value = key.childValue(e.value);
                Entry c = new Entry(key, value);
                int h = key.threadLocalHashCode & (len - 1);
                while (table[h] != null)
                    h = nextIndex(h, len);
                table[h] = c;
                size++;
            }
        }
    }
}
复制代码

可以看到该构造函数内部把父线程的 inheritableThreadLocals 成员变量的值复制到新的 ThreadLocalMap 中(通过调用 inheritableThreadLocals 类重写的 childValue 方法)。

所以 inheritableThreadLocals 的原理是通过重写 getMap 和 createMap 方法让本地变量保存到了具体线程的 inheritableThreadLocals 变量中,当父线程创建子函线程时,构造函数就会把父线程的 inheritableThreadLocals 变量中的本地变量复制一份保存到子线程的 inheritableThreadLocals 变量中。

5.2 InheritableThreadLocal 基本使用

public class TestThreadLocal2 {

    // 1 创建线程变量
    public static ThreadLocal<String> threadLocal1 = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        threadLocal1.set("hello world");
        Thread thread = new Thread(new Runnable() {
            @Override
            public void run() {
                System.out.println("thread: " + threadLocal1.get());
            }
        });
        thread.start();
        System.out.println("main: " + threadLocal1.get());
    }
}
复制代码

运行结果如下:

main: hello world
thread: hello world
复制代码

5.3 InheritableThreadLocal 使用场景

InheritableThreadLocal 并不是不可替代的,例如可以在父线程中构造一个 map 作为参数传递给子线程中。这种子线程需要获取父线程的 threadLocal 变量的场景还是蛮多的,例如子线程需要使用父线程 threadLocal 变量中的用户登录信息,再比如一些中间件需要把统一的 id 追踪的整个调用链路记录下来。

原文  https://juejin.im/post/5ead00d55188256d8136b312
正文到此结束
Loading...