ConcurrentHashMap源码.md

Image by Lou Kelly from Pixabay

前言

正文

jdk版本:1.8.0_181

数据结构

数组,链表 红黑树;数据结构和HashMap数据结构一样;

构造方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

    /**
     * Creates a new, empty map with the default initial table size (16).
     */
    // 无参构造方法
    public ConcurrentHashMap() {
    }
    
    // 初始化指定容量的构造方法
    public ConcurrentHashMap(int initialCapacity) {
        if (initialCapacity < 0)
            throw new IllegalArgumentException();
        int cap = ((initialCapacity >= (MAXIMUM_CAPACITY >>> 1)) ?
                   MAXIMUM_CAPACITY :
                   tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1));
        this.sizeCtl = cap;
    }

常用的构造方法主要就是上面两种;
第一种:什么都没做,所以会使用默认的配置,默认配置就是数组长度为16
第二种:指定容器数组长度,这里和HashMap的指定容器长度的构造方法差不多,都是调用了tableSizeFor方法, 不同的是HashMap直接将传入的值作为参数去调用tableSizeFor方法,而ConcurrentHashMap将传入的值进行了 initialCapacity + (initialCapacity >>> 1) + 1,也就是在原来的基础上大概增加了一半

添加方法

提供对外调用的添加方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

    public V put(K key, V value) {
        return putVal(key, value, false);
    }

    public void putAll(Map<? extends K, ? extends V> m) {
        tryPresize(m.size());
        for (Map.Entry<? extends K, ? extends V> e : m.entrySet())
            putVal(e.getKey(), e.getValue(), false);
    }

    public V putIfAbsent(K key, V value) {
        return putVal(key, value, true);
    }

第一个put方法,这个方法是最常用的; 第二个是将Map中的数据添加进ConcurrentHashMap,实际使用for循环Map,再put
第三个也是添加,但是是在没有key值的情况下才会添加;这种适合设置默认值的时候用;

上面三个方法都调用了putVal方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

final V putVal(K key, V value, boolean onlyIfAbsent) {
        // 判断空
        if (key == null || value == null) throw new NullPointerException();
        // 计算hash值
        int hash = spread(key.hashCode());
        int binCount = 0; 
        // 循环(cas)
        for (Node<K,V>[] tab = table;;) {
            Node<K,V> f; int n, i, fh;
            // 判断是否已经初始化
            if (tab == null || (n = tab.length) == 0)
                tab = initTable();
            // 判断对应(桶)数组位置是否有数据
            else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
                // cas将数据添加至(桶)数组对应头结点(线程安全)
                if (casTabAt(tab, i, null,new Node<K,V>(hash, key, value, null)))
                    break;                   // no lock when adding to empty bin
            }
            // 判断当前(桶)数组对应位置是否为MOVED,判断当前Map是否在扩容
            else if ((fh = f.hash) == MOVED)
                // 加速扩容
                tab = helpTransfer(tab, f);
            else {
                V oldVal = null;
                // 使用synchronized将对应位置锁住(线程安全)
                synchronized (f) {
                    if (tabAt(tab, i) == f) {
                        if (fh >= 0) {
                            binCount = 1;
                            // 链表处理,和HashMap差不多
                            for (Node<K,V> e = f;; ++binCount) {
                                K ek;
                                // 判断是否有相同的key,找到就替换
                                if (e.hash == hash &&
                                    ((ek = e.key) == key ||
                                     (ek != null && key.equals(ek)))) {
                                    oldVal = e.val;
                                    if (!onlyIfAbsent)
                                        e.val = value;
                                    break;
                                }
                                Node<K,V> pred = e;
                                // 没有找到相同的key,就直接在末尾添加
                                if ((e = e.next) == null) {
                                    pred.next = new Node<K,V>(hash, key,value, null);
                                    break;
                                }
                            }
                        }
                        // 判断是否红黑树(也和HashMap差不多)
                        else if (f instanceof TreeBin) {
                            Node<K,V> p;
                            binCount = 2;
                            if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,value)) != null) {
                                oldVal = p.val;
                                if (!onlyIfAbsent)
                                    p.val = value;
                            }
                        }
                    }
                }
                if (binCount != 0) {
                    // 判断是否需要树化
                    if (binCount >= TREEIFY_THRESHOLD)
                        // 链表转红黑树
                        treeifyBin(tab, i);
                    if (oldVal != null)
                        return oldVal;
                    break;
                }
            }
        }
        // 修改容器中元素数量
        addCount(1L, binCount);
        return null;
    }

说明:

  1. synchronized锁住部分代码与HashMap基本差不多;但是处理流程和HashMap就不同了;
  2. 使用cassynchronized(这里是分段锁)来保证线程安全;
  3. 扩容在addCount方法中,这里需要注意binCount数值,在addCount方法中会用到;
  4. 上面这段代码中的MOVED状态,需要结合transfer方法来观察;

流程:

  1. 判断是否初始化
  2. 判断是否为第一个节点
  3. 判断是否扩容
  4. 添加进桶中 4.1 锁住 4.1.1 判断链表 4.1.1.1 判断是末尾添加还是覆盖 4.1.2 判断是否红黑树 4.1.2.1 红黑树添加 4.2 判断桶中元素数量 4.2.1 判断是否将链表转红黑树 4.2.2 判断是添加还是覆盖,添加继续执行,覆盖则返回
  5. 添加数量(addCount

初始化方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

    private final Node<K,V>[] initTable() {
        Node<K,V>[] tab; int sc;
        // 判断当前容器数组是否已经初始化
        while ((tab = table) == null || tab.length == 0) {
            // 判断sizeCtl数值,如果小于0,则代表有其他线程正在初始化
            if ((sc = sizeCtl) < 0)
                Thread.yield(); // lost initialization race; just spin
            // cas方式修改SIZECTL值
            else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
                try {
                    // 判断容器
                    if ((tab = table) == null || tab.length == 0) {
                        // 判断容器长度,没若有指定,则使用DEFAULT_CAPACITY;默认16长度就在这个地方实现
                        int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                        @SuppressWarnings("unchecked")
                        Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                        table = tab = nt;
                        // sc为数组容量3/4
                        sc = n - (n >>> 2);
                    }
                } finally {
                    //修改值
                    sizeCtl = sc;
                }
                // 停止while循环
                break;
            }
        }
        return tab;
    }

初始化也使用了cas来保证线程安全;也使用了双重校验数组长度是否为空;
还需要注意的是sizeCtl,当容器正在扩容时,sizeCtl是负数; 多线程竞争时,使用了Thread.yield();

添加容器元素数量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

    private final void addCount(long x, int check) {
        CounterCell[] as; long b, s;
        // 这里注意baseCount值,使用cas添加1,也就是目前容器中的元素数量
        if ((as = counterCells) != null ||
            !U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
            CounterCell a; long v; int m;
            boolean uncontended = true;
            if (as == null || (m = as.length - 1) < 0 ||
                (a = as[ThreadLocalRandom.getProbe() & m]) == null ||
                !(uncontended =
                  U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
                 // 分段cas方法
                fullAddCount(x, uncontended);
                return;
            }
            if (check <= 1)
                return;
            s = sumCount();
        }
        // 主要看这里
        if (check >= 0) {
            Node<K,V>[] tab, nt; int n, sc;
            // 判断容器是否需要扩容
            while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
                   (n = tab.length) < MAXIMUM_CAPACITY) {
                int rs = resizeStamp(n);
                if (sc < 0) { // 说明已经有线程在进行扩容,加速扩容
                    if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                        sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                        transferIndex <= 0)
                        break;
                    if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))// 增加一个线程,帮助扩容
                        transfer(tab, nt);
                }
                // 进行扩容,修改sizeCtl值
                else if (U.compareAndSwapInt(this, SIZECTL, sc,(rs << RESIZE_STAMP_SHIFT) + 2))
                    transfer(tab, null);
                s = sumCount();
            }
        }
    }

这个方法主要是两个功能:
一是修改容器元素数量值,也就是cas修改baseCount,但是这里需要注意的是,如果没有cas成功,则表示多线程竞争添加,就会分段cas,这里就不细说了; 二是判断是否需要扩容,也就是调用transfer扩容;若正在扩容,则加速扩容;

总结起来就是:cas,分段cas;扩容,加速扩容

扩容

ConcurrentHashMap的扩容是新建一个新的数组,容量是原来数组的两倍,然后再将原数组中元素添加到新建的数组中;

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144

    private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
        int n = tab.length, stride;
        if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
            stride = MIN_TRANSFER_STRIDE; // subdivide range
        // 新建一个数组,大小为原来的两倍 体现在 n << 1
        if (nextTab == null) {            // initiating
            try {
                @SuppressWarnings("unchecked")
                Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
                nextTab = nt;
            } catch (Throwable ex) {      // try to cope with OOME
                sizeCtl = Integer.MAX_VALUE;
                return;
            }
            nextTable = nextTab;
            transferIndex = n;
        }
        int nextn = nextTab.length;
        // 注意点1,ForwardingNode 继承了Node,ForwardingNode默认hash为MOVED
        ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
        boolean advance = true;
        boolean finishing = false; // to ensure sweep before committing nextTab
        // 循环 遍历原数组,将元素放进新数组
        for (int i = 0, bound = 0;;) {
            Node<K,V> f; int fh;
            while (advance) {
                int nextIndex, nextBound;
                if (--i >= bound || finishing)
                    advance = false;
                else if ((nextIndex = transferIndex) <= 0) {
                    i = -1;
                    advance = false;
                }
                else if (U.compareAndSwapInt
                         (this, TRANSFERINDEX, nextIndex,
                          nextBound = (nextIndex > stride ?
                                       nextIndex - stride : 0))) {
                    bound = nextBound;
                    i = nextIndex - 1;
                    advance = false;
                }
            }
            // 判断原数组中元素是否全部移动完成
            if (i < 0 || i >= n || i + n >= nextn) {
                int sc;
                // 移动完成后
                if (finishing) {
                    nextTable = null;
                    // table指向新数组
                    table = nextTab;
                    // 修改sizeCtl值,为扩容后的3/4
                    sizeCtl = (n << 1) - (n >>> 1);
                    return;
                }
                if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
                    if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                        return;
                    finishing = advance = true;
                    i = n; // recheck before commit
                }
            }
            else if ((f = tabAt(tab, i)) == null) // 判断原数组对应i的位置是否有元素
                // 没有元素,就将fwd放进原数组i对应的位置,这里就代表原数组对应i位置已经移动过
                advance = casTabAt(tab, i, null, fwd); 
            else if ((fh = f.hash) == MOVED) // 判断原数组对应i的位置是否已经被移动过
                advance = true; // already processed
            else {
                // 锁定(这段代码块具体就是将原数组对应位置的元素放进新的数组中,主要关注setTabAt方法)
                synchronized (f) {
                    // 原数组一个位置的元素会转移至新数组中的两个位置中,
                    if (tabAt(tab, i) == f) {
                        Node<K,V> ln, hn;
                        if (fh >= 0) { // 链表转移至新数组
                            int runBit = fh & n;
                            Node<K,V> lastRun = f;
                            for (Node<K,V> p = f.next; p != null; p = p.next) {
                                int b = p.hash & n;
                                if (b != runBit) {
                                    runBit = b;
                                    lastRun = p;
                                }
                            }
                            if (runBit == 0) {
                                ln = lastRun;
                                hn = null;
                            }
                            else {
                                hn = lastRun;
                                ln = null;
                            }
                            for (Node<K,V> p = f; p != lastRun; p = p.next) {
                                int ph = p.hash; K pk = p.key; V pv = p.val;
                                if ((ph & n) == 0)
                                    ln = new Node<K,V>(ph, pk, pv, ln);
                                else
                                    hn = new Node<K,V>(ph, pk, pv, hn);
                            }
                            
                            setTabAt(nextTab, i, ln); // 新数组设置头结点
                            setTabAt(nextTab, i + n, hn); // 新数组设置头结点
                            setTabAt(tab, i, fwd); // 原数组设置移动标识
                            advance = true;
                        }
                        else if (f instanceof TreeBin) { // 红黑树转移至新数组
                            TreeBin<K,V> t = (TreeBin<K,V>)f;
                            TreeNode<K,V> lo = null, loTail = null;
                            TreeNode<K,V> hi = null, hiTail = null;
                            int lc = 0, hc = 0;
                            for (Node<K,V> e = t.first; e != null; e = e.next) {
                                int h = e.hash;
                                TreeNode<K,V> p = new TreeNode<K,V>
                                    (h, e.key, e.val, null, null);
                                if ((h & n) == 0) {
                                    if ((p.prev = loTail) == null)
                                        lo = p;
                                    else
                                        loTail.next = p;
                                    loTail = p;
                                    ++lc;
                                }
                                else {
                                    if ((p.prev = hiTail) == null)
                                        hi = p;
                                    else
                                        hiTail.next = p;
                                    hiTail = p;
                                    ++hc;
                                }
                            }
                            ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
                                (hc != 0) ? new TreeBin<K,V>(lo) : t;
                            hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
                                (lc != 0) ? new TreeBin<K,V>(hi) : t;
                            setTabAt(nextTab, i, ln); // 新数组设置头结点
                            setTabAt(nextTab, i + n, hn); // 新数组设置头结点
                            setTabAt(tab, i, fwd); // 原数组设置移动标识
                            advance = true;
                        }
                    }
                }
            }
        }
    }

扩容的代码有点多,就不详细描述了:

  1. 扩容是新建一个新的数组,将原数组元素放进新数组中,流程大概和HashMap的扩容差不多;
  2. 使用casForwardingNode放入原数组对应的节点中,标志已经移动;具体看fwd元素的使用地方;
  3. 关于synchronized代码块,具体可以先看HashMap的扩容,两者是差不多的,这里就不细说了;

扩展

这个方法还有一点就是while代码块,这个地方就是实现加速扩容的地方;下面具体说:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
            while (advance) {
                int nextIndex, nextBound;
                // 判断该阶段是否完成
                if (--i >= bound || finishing)
                    advance = false;
                else if ((nextIndex = transferIndex) <= 0) {
                    i = -1;
                    advance = false;
                }
                // 计算下一阶段
                else if (U.compareAndSwapInt(this, TRANSFERINDEX, nextIndex,nextBound = (nextIndex > stride ? nextIndex - stride : 0))) {
                    bound = nextBound;
                    i = nextIndex - 1;
                    advance = false;
                }
            }

这个地方主要是进行i值计算,也就是数组(桶)的下标; 通过(transferIndex-stride)来将原数组分段;多线程的情况下通过cas来保证线程安全;

涉及参数: transferIndex:转移元素的索引;就是看转移数据转移到什么位置了; stride:每次转移的数量,这个值是根据CPU来计算的,最小值是16;

也就是当数组长度大于16时,才会分段加速;

获取方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
    public V get(Object key) {
        Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
        // 计算hash
        int h = spread(key.hashCode());
        // 判断是否初始化且hash对应桶中有数据
        if ((tab = table) != null && (n = tab.length) > 0 &&(e = tabAt(tab, (n - 1) & h)) != null) {
            // 判断是否为第一个
            if ((eh = e.hash) == h) {
                if ((ek = e.key) == key || (ek != null && key.equals(ek)))
                    return e.val;
            }   
            else if (eh < 0) // 判断是否已经扩容
                return (p = e.find(h, key)) != null ? p.val : null;
            while ((e = e.next) != null) { 直接获取
                if (e.hash == h &&
                    ((ek = e.key) == key || (ek != null && key.equals(ek))))
                    return e.val;
            }
        }
        return null;
    }

获取方法其实不难,注意点就是判断是否扩容那一步,这个find方法是在ForwardingNode类中;而不是Node类中的;

其他方法

remove方法

1
2
3
4
5
6
7
8
9
10
    // 通过key删除
    public V remove(Object key) {
        return replaceNode(key, null, null);
    }
    // 通过key和value删除
    public boolean remove(Object key, Object value) {
        if (key == null)
            throw new NullPointerException();
        return value != null && replaceNode(key, null, value) != null;
    }

上面两个方法都调用了replaceNode方法;

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    final V replaceNode(Object key, V value, Object cv) {
        // 计算hash
        int hash = spread(key.hashCode());
        for (Node<K,V>[] tab = table;;) { // 循环
            Node<K,V> f; int n, i, fh;
            // 判断是否初始化,和key对应桶是否为空
            if (tab == null || (n = tab.length) == 0 ||
                (f = tabAt(tab, i = (n - 1) & hash)) == null)
                break;
            // 判断是否扩容
            else if ((fh = f.hash) == MOVED)
                tab = helpTransfer(tab, f);
            else {
                V oldVal = null;
                boolean validated = false;
                synchronized (f) { // 锁住
                    if (tabAt(tab, i) == f) {
                        if (fh >= 0) { // 链表遍历
                            validated = true;
                            for (Node<K,V> e = f, pred = null;;) {
                                K ek;
                                if (e.hash == hash &&
                                    ((ek = e.key) == key ||
                                     (ek != null && key.equals(ek)))) {
                                    V ev = e.val;
                                    if (cv == null || cv == ev ||
                                        (ev != null && cv.equals(ev))) {
                                        oldVal = ev;
                                        if (value != null)
                                            e.val = value;
                                        else if (pred != null)
                                            pred.next = e.next;
                                        else
                                            setTabAt(tab, i, e.next);
                                    }
                                    break;
                                }
                                pred = e;
                                if ((e = e.next) == null)
                                    break;
                            }
                        }
                        else if (f instanceof TreeBin) { // 红黑树
                            validated = true;
                            TreeBin<K,V> t = (TreeBin<K,V>)f;
                            TreeNode<K,V> r, p;
                            if ((r = t.root) != null &&
                                (p = r.findTreeNode(hash, key, null)) != null) {
                                V pv = p.val;
                                if (cv == null || cv == pv ||
                                    (pv != null && cv.equals(pv))) {
                                    oldVal = pv;
                                    if (value != null)
                                        p.val = value;
                                    else if (t.removeTreeNode(p))
                                        setTabAt(tab, i, untreeify(t.first));
                                }
                            }
                        }
                    }
                }
                if (validated) {
                    // 判断key是否存在
                    if (oldVal != null) {
                        if (value == null)// 判断是否修改容量值
                            addCount(-1L, -1);
                        return oldVal;
                    }
                    break;
                }
            }
        }
        return null;
    }

sumCount方法

这个方法是对容器中的元素进行计算;

这里主要是想说明分段cas添加的数据是保存在counterCells中的;这个情况主要发生在多线程添加冲突的情况下

1
2
3
4
5
6
7
8
9
10
11
12
13
    final long sumCount() {
        // 获取counterCells值
        CounterCell[] as = counterCells; CounterCell a;
        // baseCount(未竞争的情况下是在baseCount中的)
        long sum = baseCount;
        if (as != null) { 
            for (int i = 0; i < as.length; ++i) { // 遍历counterCells
                if ((a = as[i]) != null)
                    sum += a.value; // 累加
            }
        }
        return sum;
    }

最后

参考

  1. Java魔法类:Unsafe应用解析
  2. Java进阶(六)从ConcurrentHashMap的演进看Java多线程核心技术
  3. 并发编程——ConcurrentHashMap#helpTransfer() 分析
  4. ConcurrentHashMap 源码阅读小结
坚持原创技术分享,您的支持将鼓励我继续创作!