Java 的源代码学习(3)——ConcurrentSkipListMap

“Java 的源代码学习”系列

(1)基本类型和对应的类

(2)HashMap 和 ConcurrentHashMap

(3)ConcurrentSkipListMap

乃可能会说:“乃为了准备面试也是拼了,看了这么多 JDK 的源代码!”实际上以老夫以前面试的经验(以及别人面试的反馈),老夫在研究的 JDK 源代码基本上不会有任何可能被问到。比如上次研究的 ConcurrentHashMap,面试官问了 HashMap 以及 Hashtable 的实现细节,当时我就很好奇如果要它实现高并发的哈希表会怎么去做,于是试探性地提了一下,结果他除了想到分段锁其他毛都没有了解!这样看来要他想一个办法去实现低延迟的扩容也是太难为他了!(乃会说乃面试的时候还问面试官问题?实际上老夫会提一个和官方实现相近的方案,然后试探性地问他有没有更好的方法。)

而实际上写 JDK 源代码的人水平并不一定比乃高到哪里去。 M$ 写的类库也有很多地方实现得相当差,但总体上比 JDK 要好一些。比如 StringBuilder 的 M$ 实现(老夫以前也写过一篇博客分析它的源代码)就比 JDK 的要好很多。JDK 的实现就是基本照抄的 ArrayList,实在是太偷懒了!而 StringBuffer 直接加了一堆 synchronized 就完事了!Java 从 1.6~1.8 版本开始很多类库都请了知名了大师(比如 Doug Lea 这个人)进行重写,质量提升了一大截,所以还是相当值得一读的。

这里老夫要来研究一下 JDK 1.6 开始的 ConcurrentSkipListMap,也就是大家都很熟悉的跳表。偶们来跟着世界上顶尖的并发编程大师来学习一下如何实现一个高并发的跳表。

强烈建议乃对 ConcurrentHashMap (Java 1.8 以后的 CAS 实现)相当熟悉后再来看这个类,不然会有点吃力。如果乃不太熟悉,乃可以参考这个系列博文的第(2)篇。

为什么 JDK 要提供一个 SkipList 的实现,不是已经有用红黑树的 TreeMap 实现吗?TreeMap 并不是线程安全的。那为啥 JDK 没有提供线程安全的 TreeMap 或者平衡二叉树呢?

实际上到目前为止,人类还没有找到在树上增加和删除结点的无锁的高效实现方法。而用链表作为基本数据结构的 SkipList,是可以实现 lock-free 的。由于 SkipList 在绝大多数情况下性能可以与平衡二叉树相媲美,而实现又比平衡二叉树简单很多,因此哪来做 thread-safe 的 TreeMap 似乎是一个比较好的选择。

ConcurrentSkipListMap 的类声明如下:

public class ConcurrentSkipListMap<K,V> extends AbstractMap<K,V>
    implements ConcurrentNavigableMap<K,V>, Cloneable, Serializable {
    ...
}

AbstractMap 实际上只是简单地实现了一下 Map 接口,然后给出了一些简单的实现,没什么好看的。ConcurrentNavigableMap 接口继承了 ConcurrentMap 以及 NavigableMap 两个接口,这两个接口的描述可以参考 Java Collections Framework 这个文档。

简单来说,ConcurrentMap 这个接口在 Map 的基础上增加了原子操作 putIfAbsent、remove 以及 replace。对于并发写操作,遵循 happen-before 原则,即如果 put 操作发生在另一个线程的 remove 之前,就保证 put 操作的可见性。

NavigableMap 从名字上看起来很诡异,Map 本身不就是 Navigable 的么?实际上这里是指可以按照范围查找 Map 中的一些元素,比如可以查找比某个 key 小的最大 key、可以查找某个区间内的所有元素组成的子 Map,等等。

下面进入正题。

如上图所示(图片 copy 自维基百科),跳表有头结点(本类称之为 HeadInxex)、索引节点(Index)两类。索引节点组织成一个链表,而不同节点还需要一个向下的指针域,因此 Index 类的定义如下:

    static class Index<K,V> {
        final Node<K,V> node;
        final Index<K,V> down;
        volatile Index<K,V> right;
        ......
    }

那么我们怎么设计一个可以并发设计的单链表呢?假如我们有如下链表:

对于新增的结点 20,我们只需用 CAS 操作吧 结点 10 的 next 修改成 20 就可以了。但是对于删除操作,就没有这么简单了:

如上图所示,我们需要删除 10 这个结点。假设此时有另外一个线程同时执行上面我们做的“在 10 与 30 之间插入 20 的操作”,就会出现问题:删除的 CAS 操作并不知道此时有结点操作,在删除操作执行完成的时候(H 的 next 已经指向了 30),新插入的 20 这个结点就消失了。

为此,Tim Harris 这个人在 2001 年提出了一个方案(参见这篇论文,这些的示意图也盗用自这篇论文):把删除操作分成两步,第一步把要删除的结点的 next 指针标记为“已删除”;第二步,找一个机会删除掉被标记的结点。

对于第一步,插入操作需要判断结点 10 是否已经被删除,如果不是,就可以把 10 的 next 指针指向 20。请注意,这个操作的判断、修改两个步骤必须要是原子的!否则就会面临上面所说的问题。

JDK 提供了一个 AtomicMarkableReference<T> 类,内部包装了一个 Pair<T> 对象,除了指向 T 的引用外,附带了一个 boolean 类型的标记字段,使用这个标记来表示结点的 next 指针是否有效,就可以达到目的了。其 compareAndSet 方法如下:

    public boolean compareAndSet(V       expectedReference,
                                 V       newReference,
                                 boolean expectedMark,
                                 boolean newMark) {
        Pair<V> current = pair;
        return
            expectedReference == current.reference &&
            expectedMark == current.mark &&
            ((newReference == current.reference &&
              newMark == current.mark) ||
             casPair(current, Pair.of(newReference, newMark)));
    }

然而,对于 ConcurrentSkipListMap,作者并不想设置一个额外的字段来标记某个结点是否被删除。因为这就意味着所有结点的指针域都要附加一个 boolean 类型(注意作为字段的 boolean 除非在是一个数组,否则需要占用 4 个字节,而不是 1 位!)。此外,附加这个标记位也需要额外的性能开销(比如乃看上面的 CAS 代码),JDK 跳表的作者 Doug Lea 这个类的内存与性能开销太大,因为对于大部分情况会处于“被删除”的结点是少数,如果每一个都加额外的字段,在遍历时每一个都进行判断,确实没有必要。

于是,Doug Lea 想了一个办法,在删除结点时,在被删除结点的后面放置一个 marker 结点,同时将被删除结点的 value 设置为 null。在遍历时如果看到 marker 结点就知道前一个结点已经被标记位删除。

采用这样的办法,删除操作就分成了以下几步:

  1. 将 n 的 value 用 CAS 修改成 null。如果 CAS 失败,重试即可。注意其他操作仍然有可能在此时修改 n 的 next 指针。
  2. 将 n 的 next 指针用 CAS 指向新生成的 marker 结点。之后,其他线程便【不能】修改 n 的 next 指针为新增的结点了(因为新增之前要判断 n 的下一个结点是否为 marker 结点,我们会在代码里面看到这一段)。
  3. 将 b 的 next 直接指向 marker 的下一个结点 f。此时 n 和 marker 便不可达,可以被垃圾回收了。

注意这里的第一步完成后不会马上进行 2、3 步,而是在下一次遍历时进行。如果第 2、3 步 CAS 失败,说明其他线程可能也在尝试删除 n 和 marker。

对于具体的含义,偶们边看代码边分析。本类中,marker 结点的特性就是 value 指针指向自己(因为 Java 的“假”泛型,key 和 value 会被类型消除为 Object,正好利用了这个蛋疼的特性!)。实际上为了节省内存,可以专门为 marker 结点定义一个子类并去掉 value 字段,但 JDK 暂时没有使用这样的优化,因为对于大部分情况来说,marker 结点不会很多,并且大多也会很快被删除。

在添加某个结点前,需要找到前面一个结点。先看看 findPredecesoor 方法:

    private Node<K,V> findPredecessor(Object key, Comparator<? super K> cmp) {
        if (key == null)
            throw new NullPointerException(); // don't postpone errors
        for (;;) {
            for (Index<K,V> q = head, r = q.right, d;;) {
                if (r != null) {
                    Node<K,V> n = r.node;
                    K k = n.key;
                    if (n.value == null) {
                        if (!q.unlink(r))
                            break;           // restart
                        r = q.right;         // reread r
                        continue;
                    }
                    if (cpr(cmp, key, k) > 0) {
                        q = r;
                        r = r.right;
                        continue;
                    }
                }
                if ((d = q.down) == null)
                    return q.node;
                q = d;
                r = d.right;
            }
        }
    }

代码相当简单,一眼就能看懂:判断右侧结点的 key 是否比待查找的大,如果是,则当前结点右移,否则下移一层缩小范围进一步查找。与单线程的跳表不同的是,这里需要处理 n.value == null 的情况,也就是上面删除链表操作中标记为“已删除”的结点。q.unlink 方法如果失败,说明其他线程已经把这个 Index 给删掉了,此时对于“右侧”索引的判断已经失效,break 掉重新取 head 进行查找。如果成功,此时新的右结点变成被删结点的下一个结点。Index 的 unlink 方法如下:

        final boolean unlink(Index<K,V> succ) {
            return node.value != null && casRight(succ, succ.right);
        }

酱紫就可以开始看 doPut 方法了。

    private V doPut(K key, V value, boolean onlyIfAbsent) {
        Node<K,V> z;             // added node
        if (key == null)
            throw new NullPointerException();
        Comparator<? super K> cmp = comparator;
        outer: for (;;) {
            for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {
                if (n != null) {
                    Object v; int c;
                    Node<K,V> f = n.next;
                    if (n != b.next)               // inconsistent read
                        break;

乃可以看到又是这种双重死循环。对于 b.next 进行了两次判断,当 b.next 不为 n 时,此时我们假设的 b -> n -> f 关系已经被其他线程改了,无法进行接下来的删除操作。

                    if ((v = n.value) == null) {   // n is deleted
                        n.helpDelete(b, f);
                        break;
                    }

如果待插入结点的右侧结点已经被删除,调用 helpDelete 方法:

        void helpDelete(Node<K,V> b, Node<K,V> f) {
            /*
             * Rechecking links and then doing only one of the
             * help-out stages per call tends to minimize CAS
             * interference among helping threads.
             */
            if (f == next && this == b.next) {
                if (f == null || f.value != f) // not already marked
                    casNext(f, new Node<K,V>(f));
                else
                    b.casNext(this, f.next);
            }
        }

回顾一下最开始提到的删除方法:先把当前结点的值置为 null,然后在后面插入一个 marker 结点,最后把前继结点的 next 设置为 marker 结点的下一个结点。对于 marker 结点,此类采用 value == this 进行判断(<K, V> 在类型消除后都是 Object 类型!)。

  • 如果当前结点的下一个结点是 next 字段的值,并且当前结点是前一结点 b 的下一个结点,说明在此期间没有其他线程同时进行 delete,否则就不去管这个删除操作,交给正在进行的线程继续完成好了。
  • 如果下一个结点不是 marker 结点(空结点或其他 value != this 的结点),则新建一个 maker 结点放在被删除结点的右边。
  • 如果下一个结点已经是 marker 结点,此时可以安全地把前一结点的后继置为 marker 结点的下一结点。

在 helpDelete 后,无论成功或者失败,都说明待插入位置的右边有元素变动,需要重新取一下 b 后侧的结点。乃可能会说,这里执行完第一步删除的标记工作后为啥不继续把前继结点的 next 改成目标的就可以了?实际上这和在下一个 for 循环里重新取一次再判断效果是差不多的,因为其他线程可能也在调用这个方法,还是要在这里写一个 for 继续重试。

继续看 doPut 方法:

                    if (b.value == null || v == n) // b is deleted
                        break;

这里应对的是待插入结点的前继结点被删除的情况。v == n 这个判断可能有点绕,实际上是 n.value == n,表示 n 结点是一个 marker 结点,而 marker 结点前面就是待删除的结点了(因此 b.value == null 这个检查是可以去掉的)。这会发生在什么时候呢?比如我们在取出 b 后,另一个线程删除了 b。在此情况下,我们【不能】helpDelete 前一个结点 b,因为此时我们不知道 b 的前一个结点是什么,那么我们最好是“期待”下一次调用 findPredecessor 时找到 b 的前一个结点。

                    if ((c = cpr(cmp, key, n.key)) > 0) {
                        b = n;
                        n = f;
                        continue;
                    }

接下来,如果 n 结点有效,我们比较 n 结点的 key 是否比要插入结点的 key 大,如果是,链表向后移动一个位置。乃说,findPredecessor 不是已经找到比它小的最大 key 了么?说得好有道理,不过在此期间已经有其他结点插入了怎么办?

                    if (c == 0) {
                        if (onlyIfAbsent || n.casValue(v, value)) {
                            @SuppressWarnings("unchecked") V vv = (V)v;
                            return vv;
                        }
                        break; // restart if lost race to replace value
                    }
                    // else c < 0; fall through
                }

如果 n 结点的 key 与要插入的相同,根据参数判断是否要更新即可。

                z = new Node<K,V>(key, value, n);
                if (!b.casNext(n, z))
                    break;         // restart if lost race to append to b
                break outer;
            }
        }

如果一切 OK,就尝试在 b 的后面插入一个新结点。如果成功插入,则大功告成!break 掉外层循环。至此,我们可以看出,由于引入了 marker 标记,原先需要处理的 b -> n -> f 序列变成了 b -> n -> marker -> f 序列,算法会稍微复杂一些。

所以 doPut 就看完了?图样图森破!乃忘记了我们还要更新蛋疼的 Index 结点。所以继续。先回顾一下,在跳表插入结点后,对于每一层,都有一定的概率出现在更高一层。

        int rnd = ThreadLocalRandom.nextSecondarySeed();
        if ((rnd & 0x80000001) == 0) { // test highest and lowest bits
            int level = 1, max;
            while (((rnd >>>= 1) & 1) != 0)
                ++level;

乃看,如果随机数的最高位和最低位有一个不为 0,就不会进入 if 分支,此时结点不增加新的 Index(3/4 的概率)。其次,对于随机数的其他,如果为 1 则 level++,也就是说每层会有 50% 的概率出现在更高一层。

            Index<K,V> idx = null;
            HeadIndex<K,V> h = head;
            if (level <= (max = h.level)) {
                for (int i = 1; i <= level; ++i)
                    idx = new Index<K,V>(z, idx, null);
            }

如果跳表已经有这么多层了,每层新建一个 Index 结点就好啦。

            else { // try to grow by one level
                level = max + 1; // hold in array and later pick the one to use
                @SuppressWarnings("unchecked")Index<K,V>[] idxs =
                    (Index<K,V>[])new Index<?,?>[level+1];
                for (int i = 1; i <= level; ++i)
                    idxs[i] = idx = new Index<K,V>(z, idx, null);
                for (;;) {
                    h = head;
                    int oldLevel = h.level;
                    if (level <= oldLevel) // lost race to add level
                        break;
                    HeadIndex<K,V> newh = h;
                    Node<K,V> oldbase = h.node;
                    for (int j = oldLevel+1; j <= level; ++j)
                        newh = new HeadIndex<K,V>(oldbase, newh, idxs[j], j);
                    if (casHead(h, newh)) {
                        h = newh;
                        idx = idxs[level = oldLevel];
                        break;
                    }
                }
            }

不然就要新增一层了。不过,考虑到其他执行 doPut 的线程也在干这种事,所以要用 CAS 来设置新的 head。如果其他线程在此时已经把层数增加到 level 以上的话,break 掉就好了。

当层数已经达到要求后,就需要在每一层找一个合适的位置插入 Index 结点。

            splice: for (int insertionLevel = level;;) {
                int j = h.level;
                for (Index<K,V> q = h, r = q.right, t = idx;;) {
                    if (q == null || t == null)
                        break splice;
                    if (r != null) {
                        Node<K,V> n = r.node;
                        // compare before deletion check avoids needing recheck
                        int c = cpr(cmp, key, n.key);
                        if (n.value == null) {
                            if (!q.unlink(r))
                                break;
                            r = q.right;
                            continue;
                        }
                        if (c > 0) {
                            q = r;
                            r = r.right;
                            continue;
                        }
                    }

对于每个 Index,都要像 Node 那样处理被删除的情况(q.unlink),如果 unlink 失败,说明其他线程已经把它干掉了,break 掉再来一次。与插入 Node 相同,还要考虑其他线程在此处增加结点的情况,所以还要进行 c > 0 的判断。

                if (j == insertionLevel) {
                    if (!q.link(r, t))
                        break; // restart
                    if (t.node.value == null) {
                        findNode(key);
                        break splice;
                    }
                    if (--insertionLevel == 0)
                        break splice;
                }

                if (--j >= insertionLevel && j < level)
                    t = t.down;
                q = q.down;
                r = q.right;
            }
        }
    }
    return null;
}

达到了 level 这一行,就要开始把 Index 结点附加上去。如果在此时其他线程又把这个新增的结点删掉了(t.node.value == null),也就没必要再新增其他 Index 了。调用一次 findNode,遍历一遍相关的 Index 和 Node,“前功尽弃”就可以了。

折腾了老半天,doPut 总算看完了。doGet 和 findNode 这两个方法与 doPut 中刚开始的双重 for 循环类似,也需要处理结点删除的情况,代码大致类似,就不贴了。可以看到,在所有的查找操作,都要处理删除的情况(包括在新增 Index 的过程中,结点在此时有可能被删除)。

接下来开始看 doRemove,doRemove 时,首先也要找到对应结点,类似于 doPut,需要处理删除,在这里就不费口水了。

    final V doRemove(Object key, Object value) {
        if (key == null)
            throw new NullPointerException();
        Comparator<? super K> cmp = comparator;
        outer: for (;;) {
            for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {
                Object v; int c;
                if (n == null)
                    break outer;
                Node<K,V> f = n.next;
                if (n != b.next)                    // inconsistent read
                    break;
                if ((v = n.value) == null) {        // n is deleted
                    n.helpDelete(b, f);
                    break;
                }
                if (b.value == null || v == n)      // b is deleted
                    break;
                if ((c = cpr(cmp, key, n.key)) < 0)
                    break outer;
                if (c > 0) {
                    b = n;
                    n = f;
                    continue;
                }

与 doPut 有点不太一样的是,这里判断 key 与 n,如果 n.key < key,说明相关结点已经被删除掉了,return null 即可。

                if (value != null && !value.equals(v))
                    break outer;

这个方法接受一个 value 参数,可以做到只有 value 也相符的时候才删除结点。这个在并发处理时很有用,乃想一下用 redis 实现分布式锁,remove 的时候如果不去检查 value,可能把别的线程加的锁给 remove 掉了!

                if (!n.casValue(v, null))
                    break;
                if (!n.appendMarker(f) || !b.casNext(n, f))
                    findNode(key);                  // retry via findNode
                else {
                    findPredecessor(key, cmp);      // clean index
                    if (head.right == null)
                        tryReduceLevel();
                }
                @SuppressWarnings("unchecked") V vv = (V)v;
                return vv;
            }
        }
        return null;
    }

真正设置 value 为 null 以及添加 marker 的步骤很好理解。如果第一层只有一个 Index 结点了(再判断之前先要把待删除的 Index 都清理干净),就可以把它干掉,调用 tryReduceLevel:

    private void tryReduceLevel() {
        HeadIndex<K,V> h = head;
        HeadIndex<K,V> d;
        HeadIndex<K,V> e;
        if (h.level > 3 &&
            (d = (HeadIndex<K,V>)h.down) != null &&
            (e = (HeadIndex<K,V>)d.down) != null &&
            e.right == null &&
            d.right == null &&
            h.right == null &&
            casHead(h, d) && // try to set
            h.right != null) // recheck
            casHead(d, h);   // try to backout
    }

请注意,这个方法是有问题的!h、d、e 分别表示第一、二、三行,如果这三行都为空,会尝试把 head 设置为第二行 d。但是!此时其他线程可能恰好在判断完之后,与 casHead(h, d) 执行之前,新增了 Index 结点!此时 casHead 就“冤枉”了 h 结点,因为恰好它的 right 又有新的结点产生,不应该被删除掉。

可以看到作者在 casHead 之后再判断了一次 h.right,如果 h 的右侧又新增了结点,撤销这次操作把 h 再放回去。但是这仍有可能有问题,因为你的判断和操作仍然不是原子的,再第二次 CAS 之前,可能又会有新的线程来“捣乱”。

根据代码中的注释,作者认为出现“误操作”的情况相当罕见,因为它需要恰好在判断完 h、d、e 后准备 casHead 之前被中断,另一线程恰好要把高度增加到第一行(如果层级比较多,后者的概率本身就不高)。增加了第二次判断后,出问题的可能性就在第二次 CAS 之前再次被打断,连续两次出现这种情况可能性就更小了。如果出现这种情况,那只能说明乃脸黑,作者也救不了你!

不过好在这个 Index 层数少了只会在一定程度上影响跳表的性能。此外,对于一个元素比较少的跳表如果频繁有插入和删除操作,可能会产生大量层级,反而更影响性能。所以偶然删掉一个层级貌似也不全是坏事。

至此,关于 ConcurrentSkipListMap 的重点内容偶就已经学习完了。

最后老夫要来唠叨一下,有些人可能觉得如此地优化内存占用以及性能没有太大必要。乃可能会说,虽然这样优化内存占用,但是好像然并卵。其实作为 JVM 的底层实现以及类包,对内存确实是锱铢必较的,每一个字节甚至是每一位都是能省则省。但是,随着软件越来越复杂,占用的内存也不是 JVM 省的那一些内存能控制的。尤其是 UI,那占的内存真是超级大(比如老夫曾经研究过 M$ 是怎么把控件的事件字段“压缩”起来省内存的,与此类似的还有 WPF 的依赖属性的设计,不过好像 WPF 被 M$ 自己玩死了)。其次软件的开发者水平也层次不齐,就比如某里系的那些垃圾 app,随意拿个 dump 分析下它的内存占用以及 UI 绘制就会发现实现得差得不能再差(而某里也很喜欢吹嘘自己技术强,老夫也是醉了!)。Google 好不容易在 Android 7.0 引入了 profile guided JIT 来优化内存占用,具体的数字我不知道,我们假设有 5% 吧,可是某里随便加了一个不知道干啥的功能(关键是这些 app 的趋势就是越来越臃肿),就把你的内存多吃 100 MB,那么看上去 Google 的优化就没有一点用了!

不过作为 JVM 的基础类库,做到能省则省,也算是“仁至义尽”了,你说是吧?

ConcurrentSkipListMap 是 JDK 目前唯一提供的支持并发的 NavigableMap 的实现(然而 M$ 到目前为止似乎没有提供类似的实现)。当然虽然叫 Map,但其操作是 O(logn)的。

这个序列连续学习了两个使用 CAS 类的实现了,而且逻辑一个比一个复杂,相信大家都已经晕了。那么下一个系列老夫就要开始研究大家喜闻乐见的 Java 的锁机制的相关实现了,相比于这些晦涩难懂并且超级难写的 CAS,应该会轻松不少吧。呵呵呵呵……

✏️ 有任何想法?欢迎发邮件告诉老夫:daozhihun@outlook.com