赤黒木を実装してみた

赤黒木を実装してみた。

  • 本当は永続赤黒木を実装してみたくてまずは赤黒木からやることに。
  • merge/split ベースでやる
  • merge/split の中で new/delete をやるせいか、実装が悪いせいか、Treap の方が insert がずっと速い...orz
    • new/delete を boost pool にして高速化をしたら insert は 1.7 倍速、add/min なんかはこっちの方が速いのでこんなもんなのかも。確かに赤黒木のほうがより平衡化されているっぽい。(メモリ消費量に比べると微妙な高速化とか思ったり...。)
    • よく考えたら Treap とちがって葉にしか値をもたせていないので高さが +1 増えているのでそのせいかも。
  • いくつかの問題で verify 済み。
  • 最初は 赤黒木(marge/split) - Algoogle なんかを参考にしていたんだけど merge 部分の場合分けがよくわかんなくて結局下の論文っぽい実装に落ち着いた。

http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.109.4875&rep=rep1&type=pdf

  • 二つの赤黒木 T1, T2 をマージするとき rank が等しくなるまで木を下り、等しくなったら赤ノードを親にしてマージ。後は根まで上りながら(計算量は高さ O(a.rank-b.rank) 分)左右の black height が保ちながら、赤ノードが連続しないように回転させていく。(下の図は上の論文より)

f:id:shifth:20150510111601p:plain

  • merge/split が遅すぎなんで add/get/set/min は愚直な実装で。(Treap は split/merge だけでほぼ書けた)
  • add/min は区間を分割しながら再帰していく。区間を表す代表ノードに到達したらそこで止まるので計算量は O(log(N)) くらいになるはず。
  • なんか上の論文をよく読むと split はもっと改善できそうな気がしてきた。new/delete なくせそう。単なる赤黒木なら節以外に値を入れてメモリを節約できそう。(永続化するときは値は葉にもたせないとだめなのかはまだわからない。)
  • 15/5/20 追記
    • split 時に不要になったノードを reuse として使いまわすことで高速化。(コード更新)
      • node の再初期化にコンストラクタをつかったら、( *t = node(l, r, c) )一時変数のデストラクタがよばれてしまって子ノードを delete してしまうというバグに陥って悩んだ...。
    • 節にも値をもたせるなら最後にダミーノードを取り除かないとダメなんでそこで delete の実装が必要になりそう。
template <typename T>
struct node_t {
    enum Color : char {R=0, B=1, E=-1};// E は任意の色を表す
    T     val;
    T     lazy_val;
    T     min_val;
    Color color;
    int   rank;   // 葉までの黒ノードの数(自分自身はのぞく)
    int   cnt;    // 葉の数
    node_t *lch;
    node_t *rch;
    bool  lazy;

    // 葉のコンストラクタ
    node_t(T v) : val(v),
                  lazy_val(0),
                  min_val(v),
                  color(B),
                  rank(0),
                  cnt(1),
                  lch(nullptr),
                  rch(nullptr),
                  lazy(false) {}
    // 節のコンストラクタ
    node_t(node_t *l, node_t *r, Color c) : val(0),
                                            lazy_val(0),
                                            min_val(0),
                                            color(c),
                                            rank(0),
                                            cnt(0),
                                            lch(l),
                                            rch(r),
                                            lazy(false) {
        update(this);
    }
    ~node_t() {
        delete lch;
        delete rch;
    }
    static void push(node_t *t) {
            if (!t)
                return;
            if (t->lazy)
            {
                if (is_leaf(t))
                    t->val += t->lazy_val;
                t->min_val += t->lazy_val;
                if (t->lch)
                {
                    t->lch->lazy = true;
                    t->lch->lazy_val += t->lazy_val;
                }
                if (t->rch)
                {
                    t->rch->lazy = true;
                    t->rch->lazy_val += t->lazy_val;
                }
                t->lazy_val = 0;
                t->lazy = false;
            }
    }
    static node_t *update(node_t *t) {
        if (!t) return nullptr;
        push(t);

        if (is_leaf(t)) // 葉のとき
        {
            t->cnt = 1;
            t->min_val = t->val;
        }
        else // 節のとき
        {
            node_t *lch, *rch;
            lch = t->lch;
            rch = t->rch;
            // 片方 nullptr ということはないはず
            assert(rch!=nullptr);
            assert(lch!=nullptr);
            push(lch);
            push(rch);
            t->rank = std::max(black_height(lch),
                               black_height(rch));
            t->cnt = count(lch) + count(rch);
            t->min_val = std::min(minval(lch), minval(rch));
        }
        return t;
    }
    static int black_height(node_t *t) { return t? t->rank+(t->color==B) : 0; }
    static int count(node_t *t) { return t? t->cnt : 0; }
    static int minval(node_t *t) { return t? t->min_val : std::numeric_limits<T>::max(); }
    static bool is_leaf(node_t *t) {
        return (t && !t->lch && !t->rch);
    }
    static node_t *new_node(T v) {
        return new node_t(v);
    }
    static node_t *new_node(node_t *l, node_t *r, Color c, node_t *work=nullptr) {
        if (work)
        {
            // 一時オブジェクト削除時にデストラクタが呼ばれて子ノードが
            // 削除されないように nullptr で初期化
            *work = node_t(nullptr, nullptr, c);
            work->lch = l;
            work->rch = r;
            return update(work);
        }
        return new node_t(l, r, c);
    }
    static void remove_node(node_t *t) {
        if (!t) return;
        // デストラクタで子ノードが削除されないように nullptr を設定しておく
        t->lch = t->rch = nullptr;
        delete t;
    }
    static node_t *set_colors(node_t *t, Color c, Color lc=E, Color rc=E) {
        if (!t) return t;
        if (c!=E)  t->color = c;
        if (lc!=E) t->lch->color = lc;
        if (rc!=E) t->rch->color = rc;
        return update(t);
    }
    static bool color_type(node_t *t, Color c, Color lc, Color rc) {
        bool match = true;
        if (c!=E)  match &= (t->color==c);
        if (lc!=E) match &= (t->lch->color==lc);
        if (rc!=E) match &= (t->rch->color==rc);
        return match;
    }
    // t             r
    //   r     =>  t   b
    //  a b         a
    static node_t *rotateL(node_t *t) {
        node_t *r = t->rch;
        assert(r != nullptr);
        push(r);
        t->rch = r->lch;
        r->lch = t;
        update(t);
        return update(r);
    }
    //    t        l
    //  l    =>  a   t
    // a b          b
    static node_t *rotateR(node_t *t) {
        node_t *l = t->lch;
        assert(l != nullptr);
        push(l);
        t->lch = l->rch;
        l->rch = t;
        update(t);
        return update(l);
    }
    static node_t *rotL(node_t *t) {
        return set_colors(rotateL(t),R,B,B);
    }
    static node_t *rotR(node_t *t) {
        return set_colors(rotateR(t),R,B,B);
    }
    static node_t *rotLR(node_t *t) {
        return set_colors(rotateR(rotateL(t)),R,B,B);
    }
    static node_t *rotRL(node_t *t) {
        return set_colors(rotateL(rotateR(t)),R,B,B);
    }

    // 子ノードの black height が同じはず
#define assert_equal_BH(A) assert(black_height((A)->lch)==black_height((A)->rch))

    static node_t *merge_sub(node_t *a, node_t *b, node_t *reuse) {
        push(a);
        push(b);
        if (a->rank < b->rank)
        {
            assert(!is_leaf(b));
            // rank の大きい方の子のうち他方に近い方をマージする
            b->lch = merge_sub(a, b->lch, reuse);
            update(b);

            // rank を保ったまま色を塗り直す
            assert_equal_BH(b);

            //1)  B  2)  B
            //   R      R
            //  R        R
            if ( color_type(b,B,R,E) && color_type(b->lch,R,R,E) )
                return rotR(b);
            else if ( color_type(b,B,R,E) && color_type(b->lch,R,E,R) )
                return rotLR(b);
            else
                return b; // それ以外は上のノードでの blance ににまかせる
        }
        else if (a->rank > b->rank)
        {
            assert(!is_leaf(a));
            a->rch = merge_sub(a->rch, b, reuse);
            update(a);

            assert_equal_BH(a);

            //3)  B     4)  B
            //     R         R
            //      R       R
            if ( color_type(a,B,E,R) && color_type(a->rch,R,E,R) )
                return rotL(a);
            else if ( color_type(a,B,E,R) && color_type(a->rch,R,R,E) )
                return rotRL(a);
            else
                return a;
        }
        else
        {
            // 新しい部分木の rank がマージ前のものとかわらないように赤色を設定する。
            // a,b が赤黒木なら
            //  R
            // R R のようなケースはできないはず
            assert(!(a->color==R && b->color==R));
            return new_node(a, b, R, reuse);
        }
        assert(false);
        return nullptr;
    }
    static node_t *merge(node_t* a, node_t* b, node_t *reuse=nullptr) {
        if (!a || !b)
            return (!a)? b : a;
        assert(a->color==B && b->color==B);
        node_t *c = merge_sub(update(a), update(b), reuse);
        return set_colors(c,B,E,E);
    }
    static std::pair<node_t*, node_t*> split(node_t *t, int k) {
        if (!t)
            return std::make_pair(nullptr, nullptr);
        push(t);
        if (k == 0)
            return std::make_pair(nullptr, t);
        if (k == count(t))
            return std::make_pair(t, nullptr);
        std::pair<node_t*, node_t*> res;
        node_t *lch = t->lch;
        node_t *rch = t->rch;
        node_t *reuse = t ;
        if (k < count(lch))
        {
            node_t *l, *r;
            std::tie(l,r) = split(lch, k);
            // merge を呼び出すときは赤黒木になるように根を黒色にしておく
            res = std::make_pair(l, merge(r, set_colors(rch,B), reuse));
        }
        else if (k > count(lch))
        {
            node_t *l, *r;
            std::tie(l,r) = split(rch, k-count(lch));
            res = std::make_pair(merge(set_colors(lch,B), l, reuse), r);
        }
        else
        {
            res = std::make_pair(set_colors(lch,B), set_colors(rch,B));
            remove_node(reuse);         // 不要なノードは削除
        }
        return res;
    }
    static node_t *insert(node_t *t, int k, node_t *new_t) {
        auto s = split(t, k);
        return merge(merge(s.first, new_t), s.second);
    }
    static std::pair<node_t*, node_t*> erase(node_t *t, int k) {
        auto sr = split(t, k+1);
        auto sl = split(sr.first, k);
        return std::make_pair(merge(sl.first, sr.second), sl.second);
    }
    static node_t* at(node_t *t, int k) {
        assert(t!=nullptr);
        update(t);
        if (is_leaf(t))
            return t;
        if (k < count(t->lch))
            return at(t->lch, k);
        else
            return at(t->rch, k-count(t->lch));
    }
    // at [l,r)
    // 与えられた区間を分割しながら再帰する
    // O(log(N))
    static void add(node_t *t, int l, int r, T val) {
        if (r-l <= 0)
            return;
        // 区間を表すノードにたどりつけば止める
        if(count(t) == r-l)
        {
            t->lazy_val += val;
            t->lazy = true;
            update(t);
            return;
        }
        add(t->lch, l, std::min(r, count(t->lch)), val);
        add(t->rch, l-std::min(l, count(t->lch)), r-count(t->lch), val);
        // パスを上りながらノードを更新していく
        update(t);
    }
    // at [l,r)
    // 与えられた区間を分割しながら再帰する
    // O(log(N))
    static T min(node_t *t, int l, int r) {
        if (r-l <= 0)
            return minval(nullptr);
        push(t);
        // 区間を表すノードにたどりつけば止める
        if(count(t) == r-l)
            return minval(t);
        push(t->lch);
        push(t->rch);
        return std::min(min(t->lch, l, std::min(r, count(t->lch))),
                        min(t->rch, l-std::min(l, count(t->lch)), r-count(t->lch)));
    }
};

#include <random>
#include <functional>

template <typename T>
class RBTree {
    typedef node_t<T> node;
public:
    node *root_;
    RBTree() : root_(nullptr) {
    }
    ~RBTree() {}
    void insert(int k, T val) {
        assert(0<=k && k<=size());
        root_ = node::insert(root_, k, node::new_node(val));
    }
    void erase(int k) {
            node *p;
            std::tie(root_, p) = node::erase(root_, k);
            node::remove_node(p);
    }
    int size() const { return node::count(root_); }
    // add to [l,r)
    void add(int l, int r, T val) {
            assert(l<r);
            node::add(root_, l, r, val);
    }
    // min at [l,r)
    T min(int l, int r) {
            assert(l<r);
            return node::min(root_, l, r);
    }
    T at(int k) {
        assert(0<=k && k < size());
        auto p = node::at(root_, k);
        return p->val;
    }
    T get(int k) {
            if (0<=k && k < size())
                return at(k);
            return 0;
    }
    void set(int k, T val) {
        assert(0<=k && k < size());
        add(k, k+1, val-at(k));
    }
    node *build(const std::vector<T> &vec) {
        if (vec.empty())
            return (root_ = nullptr);
        int n = vec.size();
        if (n == 1)
            return (root_ = node::new_node(vec[0]));
        int m = n/2;
        return (root_ = node::merge(build(std::vector<T>(std::begin(vec), std::begin(vec)+m)),
                                    build(std::vector<T>(std::begin(vec)+m, std::end(vec)))));
    }
};