转载

ThreadLocal源码分析

我们Threadlocal类的作用是提供一个线程间隔离,线程内部共享的数据。今天我们一起看看TreadLocal是怎么做到线程隔离的。

例子

例子同样可以在 github 中找到

public static void testThreadLocal() {
    ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
    System.out.println(Thread.currentThread().getName() + ".set: " + -1);
    threadLocal.set(-1);

    ExecutorService executorService = Executors.newCachedThreadPool();
    for (int i=1; i< 5; i++) {
        final Integer setValue = i;
        executorService.submit(() -> {
            System.out.println(Thread.currentThread().getName() + ".set: " + setValue);

            threadLocal.set(setValue);
            System.out.println(Thread.currentThread().getName() + ".get: " + threadLocal.get());
            threadLocal.remove();
        });
    }
    System.out.println(Thread.currentThread().getName() + ".get: " +  threadLocal.get());
    threadLocal.remove();
}
复制代码

运行结果:

main.set: -1
pool-1-thread-1.set: 1
pool-1-thread-2.set: 2
pool-1-thread-2.get: 2
pool-1-thread-3.set: 3
pool-1-thread-1.get: 1
pool-1-thread-3.get: 3
pool-1-thread-4.set: 4
main.get: -1
pool-1-thread-4.get: 4
复制代码

代码中threadLocal对象看着也是被多线程竞争写入的,多个线程同时对他进行写入,但每个线程get到的都是正确的结果,为什么可以做到线程隔离呢?

源码

我们先大致看看set方法

public void set(T value) {
    //得到当前线程
    Thread t = Thread.currentThread();
    //获取线程的ThreadLocalMap属性
    ThreadLocalMap map = getMap(t);
    //map不为空时,set threadlocal 和value
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value); //为空时创建一个map并将threadlocal 和value放入
}
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}
复制代码

原来threadLocal set value的时候,首先获得当前的线程对象,然后得到线程对象的ThreadLocalMap属性,然后将 threadlocal自身作为key , set到map中。图解一下Thread类和ThreadLocal类的关系。

ThreadLocal源码分析

原来Thread对象中有个ThreadLocalMap属性,ThreadLocalMap顾名思义就是存放ThreadLocal的map。所以虽然例子中看着threadLocal是竞争的写入,其实不是,都是在自己的线程对象中维护了一个threadLocal。

get方法

也清晰了,就是从Thread对象里拿key为这个threadLocal对象的 value值呗!

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        //map中当前threadlocal作为key,拿到value的值,并返回
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}
复制代码

看到这ThreadLocal类的原理就说完了。

但Doug Lea 哪是一般人,源码中围绕着减少内存泄漏做的很多努力。下面我们就看看为什么会发生内存泄漏,以及怎么防止内存泄漏。

名词解释:

什么是内存泄漏?

  • 无用对象(不再使用的对象)持续占有内存或无用对象的内存得不到及时释放,从而造成内存空间的浪费称就叫做内存泄漏。

为什么会发生内存泄漏?

  • new出来的ThreadLocal对象有两个地方引用,threadLocal变量和线程属性中的threadLocalMap中的key,如果将threadLocal变量赋值为空后,因为线程的成员变量和线程生命周期相同,垃圾回收器仍然不能回收,造成了内存泄漏。
  • 同上即使key指向的threadlocal对象被垃圾回收了,value指向的对象仍然存活着,还是有内存泄漏。

怎么解决?

  • 防止threadLocal对象的内存泄漏,使用弱引用。
  • 防止value对象的内存泄漏,使用过期检查和清理,以及提供remove方法(后面会详细介绍)。

强软弱虚四种引用的定义及使用场景

  • 强引用: 普通的引用,引用存在时,垃圾回收器不能回收。我们new出来的对象都是强引用的。
  • 软引用: 垃圾回收后,内存还是不够,进行回收,内存够用是不会回收的(不干掉你我JVM就内存溢出了)。 软引用适合做缓存
  • 弱引用: 只有弱引用指向对象时,垃圾回收时就会回收。threadLocal中用于防止内存泄漏。
  • 虚引用: 有没有垃圾回收都get不到引用,用于管理直接内存,对象回收时,放入指定队列中,垃圾回收器额外处理指向的直接内存。
  • 例子可以在 github 中查看
    ThreadLocal源码分析
    上面是方法运行时,栈中内存和堆中内存的示例图,方便我们理解。

为什么弱引用可以帮我们解决key上的内存泄漏呢?

  • 根据弱引用的定义,上图中当threadlocal变量指向threadLocal对象的强引用被干掉时(即threadlocal=null),只有map中的key弱弱的指向它,垃圾回收器看它没用了,立马回收掉。这就解决了threadlocal对象的内存泄漏。下面源码看看实现吧
static class ThreadLocalMap {
    //这里的源码可以看到map中Entry类继承了WeakReference类,key弱弱的引用ThreadLocal对象
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }
    。。。
}
复制代码

上面说了防止value对象的内存泄漏,使用过期检查和清理,以及提供remove方法,这里是ThreadLocal最复杂的一部分,我们详细看看吧。再看set方法。

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}
复制代码

调用map的set方法,并不是常用的put方法,看来有不是简单的存值啊

private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    //根据hashcode和Entry数组的长度,计算下标值
    int i = key.threadLocalHashCode & (len-1);

    //根据得到的下标值找,遇到hash冲突就向后移动一个,直到找到entry是空的节点
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        //遇到key相等,说明key之前存过,替换value值就行了
        if (k == key) {
            e.value = value;
            return;
        }

        //如果k是空,说明这里存的是一个过期数据,进行替换
        //这里会进行过期数据的清理
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    //前面的位置都被占用着,新建一个Entry放在i上
    tab[i] = new Entry(key, value);
    //将map中的size加1
    int sz = ++size;
    //扫描清理一次过期数据,如果还是达到扩容的阈值了,进行扩容
    //这里也会进行过期数据的清理
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();//先进行一次全量的扫描清理过期数据,还是快接近阈值就扩容
}
复制代码

replaceStaleEntry方法

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // 向前扫描第一个过期的节点
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i; //标识第一个需要清除的位置

    // 向后遍历
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // 向后找到了key,把value进行替换
        if (k == key) {
            e.value = value;
            //i节点设为过期数据
            tab[i] = tab[staleSlot];
            //之前的过期节点赋值为key的Entry数据
            tab[staleSlot] = e;

            // 如果staleSlot就是第一个过期数据(上面的for进行了一次向前扫描),把过期下标设为i
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            //expungeStaleEntry方法清理过期节点,并进行整理(因为存在hash冲突后移,可能某些节点的hash位置空出来了,放入对应的自己的位置,后面会有图解说明)
            //cleanSomeSlots会清理Log n次,为了效率不能每次都全量扫描   
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // staleSlot是第一个过期数据,把slotToExpunge标记为i 说明有其他过期节点
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // 过期位置赋值为用key value构建的新Entry
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 如果slotToExpunge != staleSlot说明有其他节点也过期了,继续清理一些其他过期节点
    //和for循环中slotToExpunge = i 呼应
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
复制代码

replaceStaleEntry方法顾名思义用当前的key value构造一个entry替换这个过期的Entry节点。但因为存在hash冲突后移,并不能单纯的直接替换,所以做了上面的这么多事情

//清理下标为staleSlot的过期节点
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 过期节点设为空
    tab[staleSlot].value = null; //help gc
    tab[staleSlot] = null;
    size--;

    // 清理的过程中可能之前因为存在hash冲突后移的节点,位置恰好是staleSlot,staleSlot空出来了,节点应该放在正确的位置。
    Entry e;
    int i;
    //向后扫描
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        //节点key为空,说明已过期,直接干掉
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            //计算节点的hash值,确定在数组中的位置
            int h = k.threadLocalHashCode & (len - 1);
            //如果节点不应该放在i位置上,则可能放在h到i中间的位置上
            if (h != i) {
                tab[i] = null;

                // 从h位置一直后移,找到第一个为空的位置,放在正确的位置上(hash冲突后移的逻辑)
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}
复制代码

cleanSomeSlots 清理部分过期Entry

//进行log n次扫描
    //如果没有发现过期节点返回false(没有节点移动)
    //如果发现了过期节点,清理过期节点,n重置为table数组的length,再次扫描log n次
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}
复制代码

上面是set方法中对防止内存泄漏的一些努力,每次set都会对一些过期节点进行清除整理,这一部分也是较难理解的。我们放一张图,方便大家理解。

ThreadLocal源码分析
ThreadLocal源码分析

我们看看get方法,会发现也对防止内存泄漏做了一些努力

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    //当map为空时,创建一个map并存入key:this value:null  返回null
    return setInitialValue();
}
private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        //这里是重点,当hash的位置被其他节点占用了,可能是冲突后移了,可能就是没有
        return getEntryAfterMiss(key, i, e);
}
复制代码

getEntryAfterMiss 方法

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
    //向后找,直到找到entry节点是空时,返回null
    while (e != null) {
        ThreadLocal<?> k = e.get();
        //k正好是我们要找的数据,返回节点entry
        if (k == key)
            return e;
        //如果k是空,说明是过期节点,清除该过期节点    
        if (k == null)
            expungeStaleEntry(i); 
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}
复制代码

remove方法

//手动清理threadLocal
private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        //得到key对应的entry 
        if (e.get() == key) {
            e.clear(); //将referent赋值为null
            expungeStaleEntry(i); //清理该节点
            return;
        }
    }
}
复制代码

你可能有疑问,既然set和get方法都会移除过期节点,还要我们remove吗?

强烈建议大家使用完threadlocal后一定要调用remove方法。

填坑记

我们曾经一个项目中使用了threadlocal,业务上是这样的

  • 根据参数得到一个商家类型,如果是类型A就把A放入threadlocal中,如果是类型B就不放。
  • threadlocal是线程间隔离,线程中共享的嘛,后面的代码就可以根据threadlocal.get判断商家类型了。
  • 快上线做测试回归的时候,发现总是有概率商家类型会判断错误,一群人加班,后来发现是threadlocal用完后,没有调用remove方法。你知道为什么会这样吗?
  • 因为tomcat线程池,线程是重用的,如果线程t1上次使用是被放了A进去,因为t1没有销毁,下次访问A还在里面,即使这次商家类型是B,但B没有重写进去,调用thread.get 得到的仍然是A。 所以再次建议大家使用完threadlocal后,一定要进行remove
原文  https://juejin.im/post/5ef1abf6f265da02b6432025
正文到此结束
Loading...