[模板] 二逼平衡树(线段树套Splay)

题目描述

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

  • 查询 $k$ 在区间内的排名
  • 查询区间内排名为 $k$ 的值
  • 修改某一位值上的数值
  • 查询 $k$ 在区间内的前驱(前驱定义为严格小于 $x$,且最大的数,若不存在输出 $-2147483647$ )
  • 查询 $k$ 在区间内的后继(后继定义为严格大于 $x$,且最小的数,若不存在输出 $2147483647$ )

输入格式

第一行两个数 $n,m$ 表示长度为 $n$ 的有序序列和 $m$ 个操作

第二行有 $n$ 个数,表示有序序列

下面有 $m$ 行,$opt$ 表示操作标号

  • 若 $opt=1$ 则为操作 $1$,之后有三个数 $l,r,k$ 表示查询 $k$ 在区间 $[l,r]$ 的排名
  • 若 $opt=2$ 则为操作 $2$,之后有三个数 $l,r,k$ 表示查询区间 $[l,r]$ 内排名为k的数
  • 若 $opt=3$ 则为操作 $3$,之后有两个数 $pos,k$ 表示将 $pos$ 位置的数修改为 $k$
  • 若 $opt=4$ 则为操作 $4$,之后有三个数 $l,r,k$ 表示查询区间 $[l,r]$ 内 $k$ 的前驱
  • 若 $opt=5$ 则为操作 $5$,之后有三个数 $l,r,k$ 表示查询区间 $[l,r]$ 内 $k$ 的后继

输出格式

对于操作 $1, 2, 4, 5$ 各输出一行,表示查询结果

样例输入

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

样例输出

2
4
3
4
9

数据范围及提示

$n, m \leq 5 \times 10^4$,保证有序序列所有值在任何时刻满足 $[0, 10^8]$

解题思路

线段树套平衡树模板题,在这里平衡树选用了 $Splay$

对于在线段树中的每一个区间 $[l, r]$,都用一棵平衡树来维护区间信息

对于操作 $1$,在线段树内查询 $[l, r]$ 对应的 $Splay$ 中比 $k$ 小的数的个数,相加即可。(输出答案时需 $+1$)

对于操作 $3$,在线段树内对应修改即可(先在 $Splay$ 中删除旧值,再插入新值)

对于操作 $4$,在线段树递归时不断取最大值即可

对于操作 $5$,在线段树递归时不断取最小值即可

上面的操作全都是 $O(log^2 n)$复杂度的,对于操作 $2$,它不满足区间累加性,无法通过简单操作得到结果

我们考虑二分一个值 $x$,判断 $x$ 在区间内的排名是多少,即可得到答案

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = 50000;
const int SIZ = 50000 * 25;
const int INF = 2147483647;

int n, m;
int data[MAXN];

inline int read(){
    int x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9')
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return x;
}

namespace Splay{
    int tree[SIZ][2], fa[SIZ], tot;
    int num[SIZ], cnt[SIZ], siz[SIZ];

    inline void clear(int x){
        tree[x][0] = tree[x][1] = 0;
        fa[x] = num[x] = cnt[x] = siz[x] = 0;
    }

    inline int getpath(int x){
        return tree[fa[x]][1] == x;
    }

    inline void update(int x){
        siz[x] = siz[tree[x][0]] + siz[tree[x][1]] + cnt[x];
    }

    inline void rotate(int x){
        int father = fa[x];
        int grandfather = fa[father];
        int path = getpath(x);

        tree[father][path] = tree[x][path^1];
        fa[tree[father][path]] = father;

        tree[x][path^1] = father;
        fa[father] = x;

        fa[x] = grandfather;

        if(grandfather){
            tree[grandfather][tree[grandfather][1] == father] = x;
        } 

        update(father);
        update(x);
    }

    inline void splay(int &root, int x){
        for(register int f; (f = fa[x]); rotate(x)){
            if(fa[f]){
                rotate(getpath(f) == getpath(x) ? f : x);
            }
        }
        root = x;
    }

    inline void insert(int &root, int x){
        if(!root){
            root = ++tot;
            num[tot] = x;
            cnt[tot] = siz[tot] = 1;
            return;
        }

        int pos = root;
        int f;

        while(true){
            if(num[pos] == x){
                cnt[pos]++;
                siz[pos]++;
                splay(root, pos);
                return;
            }

            f = pos;
            pos = tree[pos][num[pos] < x];

            if(!pos){
                pos = ++tot;

                cnt[pos] = siz[pos] = 1;
                num[pos] = x;

                fa[pos] = f;
                tree[f][num[f] < x] = tot;

                splay(root, pos);
                return;
            }
        }
    }

    inline int pre(int &root){
        int pos = tree[root][0];

        while(tree[pos][1])
            pos = tree[pos][1];

        return pos;
    }

    inline int pre(int &root, int x){
        int pos = root;
        int ans = -INF;

        while(pos){
            if(x > num[pos]){
                ans = max(ans, num[pos]);
                pos = tree[pos][1];
                continue;
            }

            pos = tree[pos][0];
        }

        return ans;
    }

    inline int next(int &root){
        int pos = tree[pos][1];

        while(tree[pos][0])
            pos = tree[pos][0];

        return pos;
    }

    inline int next(int &root, int x){
        int pos = root;
        int ans = INF;

        while(pos){
            if(x < num[pos]){
                ans = min(ans, num[pos]);
                pos = tree[pos][0];
                continue;
            }

            pos = tree[pos][1];
        }

        return ans;
    }

    inline int find(int &root, int x){
        int ans = 0;
        int pos = root;

        while(true){
            if(x < num[pos]){
                pos = tree[pos][0];
                continue;
            }

            ans += siz[tree[pos][0]];

            if(x == num[pos]){
                splay(root, pos);
                return ans + 1;
            }

            ans += cnt[pos];
            pos = tree[pos][1];
        }
    }

    inline int rank(int &root, int x){
        int pos = root;
        int ans = 0;

        while(pos){
            if(x < num[pos]){
                pos = tree[pos][0];
                continue;
            }

            ans += siz[tree[pos][0]];

            if(x == num[pos]){
                splay(root, pos);
                return ans;
            }

            if(x > num[pos]){
                ans += cnt[pos];
                pos = tree[pos][1];
            }
        }

        return ans;
    }

    inline void Delete(int &root, int x){
        find(root, x);

        if(cnt[root] > 1){
            cnt[root]--;
            siz[root]--;
            return;
        }

        if(!tree[root][0] && !tree[root][1]){
            clear(root);
            root = 0;
            return;
        }

        if(!tree[root][0] && tree[root][1]){
            int oldroot = root;

            root = tree[root][1];
            fa[root] = 0;

            clear(oldroot);
            return;
        }

        if(tree[root][0] && !tree[root][1]){
            int oldroot = root;

            root = tree[root][0];
            fa[root] = 0;

            clear(oldroot);
            return;
        }

        int oldroot = root;
        int Pre = pre(root);

        splay(root, Pre);

        tree[root][1] = tree[oldroot][1];
        fa[tree[root][1]] = root;

        fa[root] = 0;

        clear(oldroot);
        update(root);

        return;
    }
}

namespace SEG{
    int rt[MAXN << 2];

    inline void build(int root, int left, int right){
        int lc = root << 1, rc = root << 1 | 1, mid = (left + right) >> 1;

        for(register int i=left; i<=right; i++){
            Splay::insert(rt[root], data[i]);
        }

        if(left == right)
            return;

        if(left <= mid)
            build(lc, left, mid);
        if(mid < right)
            build(rc, mid+1, right);
    }

    inline void update(int root, int left, int right, int last, int now, int pos){
        int lc = root << 1, rc = root << 1 | 1, mid = (left + right) >> 1;

        Splay::insert(rt[root], now);
        Splay::Delete(rt[root], last);

        if(left == right)
            return;

        if(pos <= mid)
            update(lc, left, mid, last, now, pos);
        if(mid < pos)
            update(rc, mid+1, right, last, now, pos);
    }

    inline int queryRank(int root, int left, int right, int qleft, int qright, int num){
        int lc = root << 1, rc = root << 1 | 1, mid = (left + right) >> 1;

        if(qleft <= left && right <= qright){
            return Splay::rank(rt[root], num);
        }

        int ans = 0;

        if(qleft <= mid)
            ans += queryRank(lc, left, mid, qleft, qright, num);
        if(mid < qright)
            ans += queryRank(rc, mid+1, right, qleft, qright, num);

        return ans; 
    }

    inline int queryPre(int root, int left, int right, int qleft, int qright, int num){
        int lc = root << 1, rc = root << 1 | 1, mid = (left + right) >> 1;

        if(qleft <= left && right <= qright){
            return Splay::pre(rt[root], num);
        }

        int ans = -INF;

        if(qleft <= mid)
            ans = max(ans, queryPre(lc, left, mid, qleft, qright, num));

        if(mid < qright)
            ans = max(ans, queryPre(rc, mid+1, right, qleft, qright, num));

        return ans;
    }

    inline int queryNext(int root, int left, int right, int qleft, int qright, int num){
        int lc = root << 1, rc = root << 1 | 1, mid = (left + right) >> 1;

        if(qleft <= left && right <= qright){
            return Splay::next(rt[root], num);
        }

        int ans = INF;

        if(qleft <= mid)
            ans = min(ans, queryNext(lc, left, mid, qleft, qright, num));

        if(mid < qright)
            ans = min(ans, queryNext(rc, mid+1, right, qleft, qright, num));

        return ans;
    }

    inline int queryKth(int root, int left, int right, int num){
        int l = 0;
        int r = 1e8 + 10;

        while(l < r){
            int mid = ((l + r) >> 1) + 1;

            if(queryRank(1, 1, n, left, right, mid) < num)
                l = mid;
            else
                r = mid - 1;
        }

        return l;
    }
}


int main(){
    n = read();
    m = read();

    for(register int i=1; i<=n; i++)
        data[i] = read();

    SEG::build(1, 1, n);

    for(register int i=1; i<=m; i++){
        int op = read();
        int a = read();
        int b = read();

        if(op == 3){
            SEG::update(1, 1, n, data[a], b, a);
            data[a] = b;
        }else{
            int c = read();

            if(op == 1){
                printf("%d\n", SEG::queryRank(1, 1, n, a, b, c) + 1);
            }else if(op == 2){
                printf("%d\n", SEG::queryKth(1, a, b, c));
            }else if(op == 4){
                printf("%d\n", SEG::queryPre(1, 1, n, a, b, c));
            }else{
                printf("%d\n", SEG::queryNext(1, 1, n, a, b, c));
            }
        }
    }
}