[模板] 替罪羊树与Treap

引言

二叉查找树(二叉排序树)是满足左子树节点权值小于根节点,右子树节点权值大于根节点,左右子树为二叉排序树的一种数据结构,可以方便的维护一个序列,实现查询排名,前驱后继等操作

但普通的二叉查找树在进行插入与删除操作后,可能会导致子树大小失衡,因此最坏复杂度依旧是 $O(n)$,之前我们已经介绍了 $Splay$ 的相关实现和运用,它是通过旋转与 $Splay$ 操作来保证复杂度

替罪羊树

替罪羊树(重量平衡树)通过重建来 保证子树的平衡,在插入过程结束后,通过平衡因子来保证左子树大小与右子树大小近似相等,当失衡时则强制重建,删除则采用懒惰删除的方式。

数据储存与更新

因为替罪羊树使用了懒惰删除的方式,因此需要还需要储存有效节点数和总结点数。

struct Node{
    int son[2];
    int key; //键值
    int siz; //有效节点数
    int tot; //总节点数
    int cnt; //键值出现次数

    Node(){key = siz = tot = son[0] = son[1] = cnt = 0;}
};

#define lson(x) node[x].son[0]
#define rson(x) node[x].son[1]

Node node[MAXN];
int root, now;

inline void update(int root){
    node[root].siz = node[lson(root)].siz + node[rson(root)].siz + node[root].cnt;
    node[root].tot = node[lson(root)].tot + node[rson(root)].tot + node[root].cnt;
}

判断是否失衡

这里需要一个平衡因子,一般取 $0.7$ 或 $0.75$,如果某棵子树占整棵树比例超过了这个平衡因子,就说明需要重构

inline bool check(int root){
    if(node[lson(root)].tot > node[root].tot * aplha + 3 || node[rson(root)].tot > node[root].tot * aplha + 3)
        return true;
    return false;   
}

重构

先对子树做中序遍历,将非空节点存入数组中(清除懒惰删除节点),最后以建立二叉树的方式不断选中值作为子树的根节点即可

inline void search(int root, vector <int> &vec){
    if(!root)
        return;

    search(lson(root), vec);

    if(node[root].cnt)
        vec.push_back(root);

    search(rson(root), vec);
}

inline int divide(vector <int> &vec, int l, int r){
    if(l > r)
        return 0;

    int mid = (l + r) >> 1;

    int pos = vec[mid];

    lson(pos) = divide(vec, l, mid-1);
    rson(pos) = divide(vec, mid+1, r);

    update(pos);

    return pos; 
}

inline void rebuild(int& root){
    vector <int> vec;
    vec.clear();

    search(root, vec);
    root = divide(vec, 0, vec.size() - 1);
}

插入

与正常的二叉查找树一样,找到插入的位置,若不存在节点则新建节点,若存在节点则修改计数。

同时,在插入时需要找到深度最小的一个失衡节点进行重构。

这里还使用了一个小技巧,null 指向了一个值为 $0$ 的 int 型变量,而非是一个空指针

int _null = 0;
int* null = &_null;

inline int* _insert(int& root, int key){
    if(!root){
        root = newnode(key);
        return null;
    }else{

        node[root].siz++;
        node[root].tot++;

        if(node[root].key == key){
            node[root].cnt++;
            return null;
        }else{
            int* pos = _insert(node[root].son[key > node[root].key], key);

            if(check(root)){
                pos = &root;
            }

            return pos;
        }
    }
}

inline void insert(int key){
    int* pos = _insert(root, key);

    if(*pos)
        rebuild(*pos);
}

删除

删除只需要在遍历到目标节点的过程中将有效节点个数依次修改即可。
当最后整棵平衡树的有效节点个数占总节点可数比例小于平衡因子时,重构整棵平衡树。

inline void _delete(int root, int key){
    int pos = root;

    while(true){
        node[pos].siz--;

        if(key < node[pos].key){
            pos = lson(pos);
            continue;
        } 

        if(node[pos].key == key){
            node[pos].cnt--;
            return;
        }

        pos = rson(pos);
    }       
}


inline void Delete(int key){
    _delete(root, key);

    if(node[root].siz < node[root].tot * aplha + 3)
        rebuild(root);
}

排名与第 $k$ 大

与正常的二叉查找树完全一致

inline int find(int key){
    int ans = 0;
    int pos = root;

    while(true){
        if(key < node[pos].key){
            pos = lson(pos);
            continue;
        } 

        ans += node[lson(pos)].siz;

        if(node[pos].key == key){
            return ans + 1;
        }

        ans += node[pos].cnt;
        pos = rson(pos);
    }
} 

inline int rank(int x){
    int pos = root;

    while(true){
        if(x <= node[lson(pos)].siz){
            pos = lson(pos);
            continue;
        }

        x -= node[lson(pos)].siz;

        if(x <= node[pos].cnt)
            return node[pos].key;

        x -= node[pos].cnt;
        pos = rson(pos);
    }
}

前驱与后继

因为使用了懒惰删除的方式,所以前驱和后继的查询稍微有一些不同。

对于数 $x$,它的前驱的排名为 $x$ 的排名减 $1$

具体做法为插入一个数 $x$,求出它的排名 $k$,删除数 $x$ 然后查询排名为 $k-1$ 的数的值

对于数 $x$,它的后继就是排名为小于等于 $x$ 的数的个数加 $1$

具体做法为插入一个数 $x+1$,求出它的排名 $k$,删除数 $x+1$,然后查询排名为 $k$ 的数的值

inline int pre(int key){
    insert(key);
    int p = rank(find(key) - 1);
    Delete(key);

    return p;
} 

inline int next(int key){
    insert(key+1);
    int id = find(key+1);
    Delete(key+1);

    int p = rank(id);
    return p;
}

Treap

$Treap = Tree + Heap$,树堆,是通过随机附加域来保证子树的平衡的。对于每一个节点都赋给它一个随机的 $priority$ 值,使得它是一棵以 $key$ 为主键的二叉查找树的同时,还是一棵以 $priority$ 为主键的堆。

维护 $priority$ 的方式和 $Splay$ 类似,是基于旋转的。

$Treap$ 的常数小,实现也简单,但是因为当 $key$ 和 $priority$ 给定时,树的形态也就唯一确定了,无法实现更多的复杂功能。

数据储存与更新

int tree[MAXN][2], siz[MAXN], cnt[MAXN];
int val[MAXN], rd[MAXN];
int tot;

inline void update(int root){
    siz[root] = siz[tree[root][0]] + siz[tree[root][1]] + cnt[root];
}

旋转

与 $Splay$ 的旋转非常近似,但是 $Treap$ 的旋转方向不是由它与父亲的关系确定的。

inline void rotate(int &root, int path){
    //path = 0 时左儿子旋上来,path = 1 时右儿子旋上来
    int son = tree[root][path];

    tree[root][path] = tree[son][path^1];
    tree[son][path^1] = root;

    update(root);
    update(son);

    root = son;
}

插入

插入也是定位到需要插入的节点上,如果不存在则新建节点,存在则修改出现次数。对于新建节点的情况,如果随机附加域破坏了堆的性质,那么就进行旋转即可。

inline void insert(int &root, int k){
    if(!root){
        root = ++tot;
        cnt[root] = siz[root] = 1;
        val[root] = k;
        rd[root] = rand();
        return;
    }

    siz[root]++;

    if(val[root] == k){
        cnt[root]++;
        return;
    }

    int tmp = val[root] < k;

    insert(tree[root][tmp], k);

    if(rd[root] > rd[tree[root][tmp]])
        rotate(root, tmp);
}

删除

删除比较特殊,先定位需要删除的节点(在定位过程中需要依次修改信息),然后就有三种情况

  • 没有左儿子右儿子
  • 有左儿子或有右儿子
  • 左右儿子都有

对于第一种情况,直接删除即可;对于第二种情况,直接将左儿子或右儿子放在节点的位置上;对于第三种情况,将随机附加域较小的(小根堆)旋上来,然后递归下去,直到出现第一种或第二种情况

inline void Delete(int &root, int k){
    if(!root)
        return;

    if(val[root] == k){
        if(cnt[root] > 1){
            cnt[root]--;
            siz[root]--;
            return;
        }

        int path = val[tree[root][0]] > val[tree[root][1]];

        if(tree[root][0] == 0 || tree[root][1] == 0)
            root = tree[root][0] + tree[root][1];
        else{
            rotate(root, path);
            Delete(root, k);
        }
    }else{
        siz[root]--;
        Delete(tree[root][val[root] < k], k);
    }
}

前驱与后继

递归的将查找目标节点过程中小于(大于)目标节点的值取最值即可。

inline int pre(int root, int k){
    if(!root)
        return -INF;

    if(val[root] >= k)
        return pre(tree[root][0], k);

    return max(pre(tree[root][1], k), val[root]);
} 

inline int next(int root, int k){
    if(!root)
        return INF;

    if(val[root] <= k)
        return next(tree[root][1], k);

    return min(next(tree[root][0], k), val[root]);
}

排名与第 $k$ 大

与其他各类平衡树没有区别,大家可以看下面的完整模板

完整模板

Luogu 3369 – 普通平衡树

另外,为了方便读者阅读与学习,一个使用了 $struct$ 封装了节点,一个没有使用。

替罪羊树

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>

using namespace std;

namespace Scapegoat{

    struct Node{
        int son[2];
        int key, siz, tot;  
        int cnt;

        Node(){key = siz = tot = son[0] = son[1] = cnt = 0;}
    };

    const long double aplha = 0.75;
    const int MAXN = 200000 + 10;
    const int INF = 0x3f3f3f3f;

    #define lson(x) node[x].son[0]
    #define rson(x) node[x].son[1]

    Node node[MAXN];
    int root, now;

    int _null = 0;
    int* null = &_null;

    inline void update(int root){
        node[root].siz = node[lson(root)].siz + node[rson(root)].siz + node[root].cnt;
        node[root].tot = node[lson(root)].tot + node[rson(root)].tot + node[root].cnt;
    }

    inline int newnode(int key){
        ++now;

        node[now].siz = node[now].tot = node[now].cnt = 1;
        node[now].key = key;
        return now;
    }

    inline bool check(int root){
        if(node[lson(root)].tot > node[root].tot * aplha + 3 || node[rson(root)].tot > node[root].tot * aplha + 3)
            return true;
        return false;   
    }

    inline void search(int root, vector <int> &vec){
        if(!root)
            return;

        search(lson(root), vec);

        if(node[root].cnt)
            vec.push_back(root);

        search(rson(root), vec);
    }

    inline int divide(vector <int> &vec, int l, int r){
        if(l > r)
            return 0;

        int mid = (l + r) >> 1;

        int pos = vec[mid];

        lson(pos) = divide(vec, l, mid-1);
        rson(pos) = divide(vec, mid+1, r);

        update(pos);

        return pos; 
    }

    inline void rebuild(int& root){
        vector <int> vec;
        vec.clear();

        search(root, vec);
        root = divide(vec, 0, vec.size() - 1);
    }

    inline int* _insert(int& root, int key){
        if(!root){
            root = newnode(key);
            return null;
        }else{

            node[root].siz++;
            node[root].tot++;

            if(node[root].key == key){
                node[root].cnt++;
                return null;
            }else{
                int* pos = _insert(node[root].son[key > node[root].key], key);

                if(check(root)){
                    pos = &root;
                }

                return pos;
            }
        }
    }

    inline void _delete(int root, int key){
        int pos = root;

        while(true){
            node[pos].siz--;

            if(key < node[pos].key){
                pos = lson(pos);
                continue;
            } 

            if(node[pos].key == key){
                node[pos].cnt--;
                return;
            }

            pos = rson(pos);
        }       
    }

    inline void insert(int key){
        int* pos = _insert(root, key);

        if(*pos)
            rebuild(*pos);
    }

    inline void Delete(int key){
        _delete(root, key);

        if(node[root].siz < node[root].tot * aplha + 3)
            rebuild(root);
    }

    inline int find(int key){
        int ans = 0;
        int pos = root;

        while(true){
            if(key < node[pos].key){
                pos = lson(pos);
                continue;
            } 

            ans += node[lson(pos)].siz;

            if(node[pos].key == key){
                return ans + 1;
            }

            ans += node[pos].cnt;
            pos = rson(pos);
        }
    } 

    inline int rank(int x){
        int pos = root;

        while(true){
            if(x <= node[lson(pos)].siz){
                pos = lson(pos);
                continue;
            }

            x -= node[lson(pos)].siz;

            if(x <= node[pos].cnt)
                return node[pos].key;

            x -= node[pos].cnt;
            pos = rson(pos);
        }
    }

    inline int pre(int key){
        insert(key);
        int p = rank(find(key) - 1);
        Delete(key);

        return p;
    } 

    inline int next(int key){
        insert(key+1);
        int id = find(key+1);
        Delete(key+1);

        int p = rank(id);
        return p;
    }

    #undef lson
    #undef rson 
};

inline int read(){
    int x = 0;
    int p = 1;

    char ch = getchar();

    while(ch < '0' || ch > '9'){
        if(ch == '-')
            p = 0;
        ch = getchar();
    }

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return p ? x : (-x);
}

int n;

int main(){

    #ifdef DEBUG
        freopen("testdata.in", "r", stdin);
        freopen("testdata.out", "w", stdout);
    #endif

    n = read();

    for(int i=0; i<n; i++){
        int op,id;
        op = read();
        id = read();

        if(op==1)
            Scapegoat::insert(id);
        else if(op==2)
            Scapegoat::Delete(id);
        else if(op==3)
            printf("%d\n",Scapegoat::find(id));
        else if(op==4)
            printf("%d\n",Scapegoat::rank(id));
        else if(op==5){
            printf("%d\n",Scapegoat::pre(id));
        }else{
            printf("%d\n",Scapegoat::next(id));
        }
    }
    return 0;
}

Treap

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <ctime>
#include <algorithm>

using namespace std;
const int MAXN = 200000 + 10;
const int INF = 0x3f3f3f3f;

namespace Treap{
    int tree[MAXN][2], siz[MAXN], cnt[MAXN];
    int val[MAXN], rd[MAXN];
    int tot;

    inline void update(int root){
        siz[root] = siz[tree[root][0]] + siz[tree[root][1]] + cnt[root];
    }

    inline void rotate(int &root, int path){
        int son = tree[root][path];

        tree[root][path] = tree[son][path^1];
        tree[son][path^1] = root;

        update(root);
        update(son);

        root = son;
    }

    inline void insert(int &root, int k){
        if(!root){
            root = ++tot;
            cnt[root] = siz[root] = 1;
            val[root] = k;
            rd[root] = rand();
            return;
        }

        siz[root]++;

        if(val[root] == k){
            cnt[root]++;
            return;
        }

        int tmp = val[root] < k;

        insert(tree[root][tmp], k);

        if(rd[root] > rd[tree[root][tmp]])
            rotate(root, tmp);
    }

    inline int find(int root, int k){
        if(!root)
            return 0;

        if(val[root] == k)
            return siz[tree[root][0]] + 1;

        if(val[root] > k)
            return find(tree[root][0], k);

        return find(tree[root][1], k) + siz[tree[root][0]] + cnt[root];
    }

    inline int rank(int root, int k){
        int pos = root;

        while(true){
            if(k <= siz[tree[pos][0]])
                pos = tree[pos][0];
            else if(k > siz[tree[pos][0]] + cnt[pos]){
                k -= siz[tree[pos][0]];
                k -= cnt[pos];
                pos = tree[pos][1];
            }else
                return val[pos];
        }
    } 

    inline int pre(int root, int k){
        if(!root)
            return -INF;

        if(val[root] >= k)
            return pre(tree[root][0], k);

        return max(pre(tree[root][1], k), val[root]);
    } 

    inline int next(int root, int k){
        if(!root)
            return INF;

        if(val[root] <= k)
            return next(tree[root][1], k);

        return min(next(tree[root][0], k), val[root]);
    }

    inline void Delete(int &root, int k){
        if(!root)
            return;

        if(val[root] == k){
            if(cnt[root] > 1){
                cnt[root]--;
                siz[root]--;
                return;
            }

            int path = val[tree[root][0]] > val[tree[root][1]];

            if(tree[root][0] == 0 || tree[root][1] == 0)
                root = tree[root][0] + tree[root][1];
            else{
                rotate(root, path);
                Delete(root, k);
            }
        }else{
            siz[root]--;
            Delete(tree[root][val[root] < k], k);
        }
    }
}


int main(){
    srand(time(NULL));

    int n, root = 0;
    scanf("%d",&n);

    for(int i=0; i<n; i++){
        int op,id;
        scanf("%d%d",&op,&id);

        if(op==1)
            Treap::insert(root, id);
        else if(op==2)
            Treap::Delete(root, id);
        else if(op==3)
            printf("%d\n",Treap::find(root, id));
        else if(op==4)
            printf("%d\n",Treap::rank(root, id));
        else if(op==5){
            printf("%d\n",Treap::pre(root, id));
        }else{
            printf("%d\n",Treap::next(root, id));
        }
    }
    return 0;
}