[模板] 可持久化线段树 & 主席树

可持久化线段树是一种能支持访问历史版本的线段树,能实现基于某一个历史版本的操作。主席树和可持久化线段树一般是等价的,但更多时候主席树一般是一棵权值(值域)线段树

实现

先来思考如何暴力实现可持久化
对于每一步操作,我们都将之前的线段树拷贝一份,这样就可以实现可持久化了。

空间复杂度 $O(mn\ log \ n)$,非常的小。

我们来回顾一下线段树的操作,无论是区间修改还是区间查询,对应在线段树上的区间总是 $O(log \ n)$ 级别的。

也就是说,对于每一次操作,我们可以只将覆盖到了的区间拷贝一份,再对他们进行修改。

这个操作非常类似于动态开点线段树,较普通线段树相当于多记录了一个左儿子和右儿子的编号(不再是 $2n$ 和 $2n+1$ 的关系了)

以下是一个更新操作的框架

inline void update(int &root, int oldroot, int left, int right, itn qleft, int qright, ...){

    root = ++tot;

    lson[root] = lson[oldroot];
    rson[root] = rson[oldroot];

    ... //其他数值的复制、修改等

    if(left == right){
        ... //修改操作等
        return;
    }

    int mid = (left + right) >> 1;

    if(qleft <= mid)
        update(lson[root], lson[oldroot], left, mid, num);

    if(mid < qright)
        update(rson[root], rson[oldroot], mid+1, right, num);

    ... //回溯操作等
}

例题

主席树实现静态区间 $k$ 小值

我们将原先大小为 $n$ 的数组一个个插入主席树中,相当于产生了 $n$ 个版本的权值线段树。

对于区间 $[l, r]$,相当于找到版本 $l-1$(区间 $[1, l – 1]$ ) 和版本 $r$ (区间 $[1, r]$ ),区间 $[l, r]$ 内元素个数即为它们的差,找到 $k$ 小值即可

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = (200000 + 10) << 5;

inline int read(){
    int x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9')
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return x;
}

int sum[MAXN], lson[MAXN], rson[MAXN], tot;
int root[MAXN]; 

inline void build(int &root, int left, int right){
    root = ++tot;

    if(left == right)
        return;

    int mid = (left + right) >> 1;

    if(left <= mid)
        build(lson[root], left, mid);

    if(mid < right)
        build(rson[root], mid+1, right);    
}

inline void update(int &root, int oldroot, int left, int right, int num){
    root = ++tot;

    sum[root] = sum[oldroot] + 1;
    lson[root] = lson[oldroot];
    rson[root] = rson[oldroot];

    if(left == right){
        return;
    }

    int mid = (left + right) >> 1;

    if(num <= mid)
        update(lson[root], lson[oldroot], left, mid, num);

    if(mid < num)
        update(rson[root], rson[oldroot], mid+1, right, num);
}

inline int query(int leftroot, int rightroot, int left, int right, int num){
    int mid = (left + right) >> 1;

    if(left == right)
        return left;

    if(num <= sum[lson[rightroot]] - sum[lson[leftroot]]){
        return query(lson[leftroot], lson[rightroot], left, mid, num);
    }else{
        num -= sum[lson[rightroot]] - sum[lson[leftroot]];
        return query(rson[leftroot], rson[rightroot], mid+1, right, num);
    }
}

int num[MAXN];
int que[MAXN];
int n, m;

int main(){
    n = read();
    m = read();

    for(register int i=1; i<=n; i++)
        num[i] = que[i] = read();

    sort(que+1, que+n+1);
    int cnt = unique(que+1, que+n+1) - que;

    for(register int i=1; i<=n; i++){
        num[i] = lower_bound(que+1, que+cnt, num[i]) - que;
    }

    build(root[0], 1, n);

    for(register int i=1; i<=n; i++)
        update(root[i], root[i-1], 1, n, num[i]);

    for(register int i=1; i<=m; i++){
        int a = read();
        int b = read();
        int c = read();

        printf("%d\n", que[query(root[a-1], root[b], 1, n, c)]);
    }
    return 0;
}