[模板] CDQ分治

CDQ 分治是一种处理系列操作问题的一种离线算法,常数较小,可以代替一些因为内存原因无法使用的数据结构。

归并排序求逆序对

给一列数 $a_1,a_2,…,a_n$,求它的逆序对数,即有多少个有序对 $(i,j)$ ,使得 $i<j$ 且 $a_i>a_j$

对于两个已经排序好的区间,在合并过程中,考虑右区间对左区间的贡献。当从右区间取出一个数时,代表着左区间还没有被选中的数全部比它大,统计答案。

这就是 cdq 分治的一个重要思想

inline void merge(int l, int r){
    if(l == r)
        return;

    int mid = (left + right) >> 1;

    merge(l, mid);
    merge(mid+1, r);

    int i = l;
    int j = mid+1;

    for(register int k=l; k<=r; k++){
        if((i <= mid && data[i] <= data[j]) || j > r)
            tmp[k] = data[i++];
        else{
            tmp[k] = data[j++];
            cnt += mid - i + 1;
        }
    }

    for(register int k=l; k<=r; k++)
        data[k] = tmp[k];
}

二维偏序问题

给定 $n$ 个有序对 $(a, b)$,求对于每个 $(a, b)$,满足 $a'<a$ 且 $b'<b$ 的有序对 $(a’, b’)$ 的个数

我们考虑如何转化成为求逆序对的过程,我们对 $a$ 进行排序,我们就可以忽略 $a$ 的影响,这个时候问题就转化成为了求顺序对的个数

题目描述

给定一个 $N$ 个元素的序列 $a$ ,初始值全部为 $0$,对这个序列进行以下两种操作:
操作$1$:格式为 $1 \ x \ k$,把位置 $x$ 的元素加上 $k$(位置从 $1$ 标号到 $N$)。
操作$2$:格式为 $2 \ x \ y$,求出区间 $[x,y]$ 内所有元素的和。

解题分析

我们将每一个操作都用一个二元组表示,然后我们对这个操作序列进行 CDQ 分治。

我们在这个二元组中再维护一个附加信息 $type$,用来表示这个操作是查询操作还是修改操作。

修改操作需要被拆分为两个二元组,一个用以查询 $l-1$ 的前缀和,另一个用以查询 $r$的前缀和

每一次合并,我们都只去考虑左区间的修改对右区间的查询的影响。

修改操作优先于查询操作

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = 1500000 + 10;

inline int read(){
    int x = 0;
    int p = 1;

    char ch = getchar();

    while(ch < '0' || ch > '9'){
        if(ch == '-')
            p = 0;
        ch = getchar();
    }

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return p ? x : (-x);
}

struct Statement{
    int type;
    int id;
    int val;

    bool operator < (const Statement x) const{
        if(id == x.id)
            return type < x.type;
        return id < x.id;
    }
}data[MAXN];

int n, m;
int cnt, query_cnt;
int ans[MAXN];
Statement tmp[MAXN];

inline void cdq(int l, int r){
    int mid = (l + r) >> 1;

    if(r == l)
        return;

    cdq(l, mid);
    cdq(mid + 1, r);

    int a = l, b = mid + 1;
    int sum = 0;

    for(register int k=l; k<=r; k++){
        if((a <= mid && data[a] < data[b]) || b > r){
            if(data[a].type == 1){
                sum += data[a].val; 
            }
            tmp[k] = data[a++];         
        }else{
            if(data[b].type == 2){
                ans[data[b].val] -= sum;
            }else if(data[b].type == 3){
                ans[data[b].val] += sum;
            }
        tmp[k] = data[b++];
        }
    }

    for(register int k=l; k<=r; k++){
        data[k] = tmp[k];
    }

}
int main(){
    n = read();
    m = read();

    for(register int i=1; i<=n; i++){
        data[++cnt].type = 1;
        data[cnt].id = i;
        data[cnt].val = read();
    }

    for(register int i=1; i<=m; i++){
        int k = read();

        if(k == 1){
            data[++cnt].type = 1;   
            data[cnt].id = read();
            data[cnt].val = read();
        }else{
            int x = read();
            int y = read();

            data[++cnt].type = 2;   
            data[cnt].id = x - 1;
            data[cnt].val = ++query_cnt;

            data[++cnt].type = 3;   
            data[cnt].id = y;
            data[cnt].val = query_cnt;                      
        }
    }

    cdq(1, cnt);

    for(register int i=1; i<=query_cnt; i++){
        printf("%d\n", ans[i]);
    }

    return 0;
}

三维偏序问题

给定 $N$ 个有序三元组 $(a,b,c)$,求对于每个三元组 $(a,b,c)$,有多少个三元组 $(a2,b2,c2)$ 满足 $a2<a$ 且 $b2<b$ 且 $c2<c$。

本质上依旧是降维,现将 $a$ 元素排序,即可忽略 $a$ 元素的影响,然后 $b$ 从小到大顺序归并,但是无法直接求解,需要将 $c$ 存入树状数组中查询。

对于左区间中的三元组 $(a,b,c)$,根据 $c$ 在树状数组中查询修改。

题目描述

花果山国土辽阔,地大物博……但是最近却在闹蝗灾…..
我们可以把花果山国土当成一个 $W \times W$ 的矩阵,你会收到一些诸如$(X,Y,Z)$的信息,代
表 $(X,Y)$这个点增多了 $Z$ 只蝗虫,而由于孙悟空走后花果山政府机关比较臃肿,为了批复
消灭蝗虫的请求需要询问一大堆的问题……每个询问形如 $(X1,Y1,X2,Y2)$,询问在 $(X1,Y1,X2,Y2)$
范围内有多少蝗虫(请注意询问不会改变区域内的蝗虫数),你作为一个花果山的猴几,需
要编一个程序快速的回答所有的询问。
注意:花果山一开始没有蝗虫。

解题分析

数据范围较小的情况可以用二维树状数组解决,但本题显然不行。

依旧将每个操作用三元组表示,我们也依旧需要在三元组中记录一个操作类型信息,但此时查询请求比较复杂,我们用一个参数 $w$ 描述答案时加还是减。

修改操作依旧优先于查询操作

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;

const int MAXN = 800000 + 10;
const int MAXM = 500000 + 10;

inline int read(){
    int x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9')
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }
    return x;
}

int n, m;
int cnt, query_cnt;
int ans[MAXN];

struct Statement{
    int type;
    int x, y;
    int w;
    int val;

    bool operator < (const Statement a) const{
        if(x != a.x || y != a.y){
            if(x == a.x)
                return y < a.y;
            return x < a.x;
        }
        return type < a.type;
    }
}data[MAXN], tmp[MAXN];

struct BIT{
    int num[MAXM];

    inline int lowbit(int x){
        return x & (-x);
    }

    inline void clear(int a){
        while(a <= n){
            if(num[a]){
                num[a] = 0;
            }else{
                break;
            }

            a += lowbit(a);
        }
    }

    inline void update(int a, int b){
        while(a <= n){
            num[a] += b;
            a += lowbit(a);
        }
    }

    inline int query(int a){
        int ans = 0;

        while(a){
            ans += num[a];
            a -= lowbit(a);
        }

        return ans;
    }
}tree;

inline void cdq(int l, int r){
    if(l == r)
        return;

    int mid = (l + r) >> 1;

    cdq(l, mid);
    cdq(mid + 1, r);

    int a = l, b = mid + 1;

    for(register int k=l; k<=r; k++){
        if((a <= mid && data[a] < data[b]) || b > r){
            if(data[a].type == 1){
                tree.update(data[a].y, data[a].val);
            }

            tmp[k] = data[a++];
        }else{
            if(data[b].type == 2){
                ans[data[b].val] += data[b].w * tree.query(data[b].y);
            }

            tmp[k] = data[b++];
        }
    }

    for(register int k=l; k<=r; k++){
        tree.clear(tmp[k].y);
        data[k] = tmp[k];
    }

}
int main(){

    n = read();
    m = read();

    for(register int i=1; i<=m; i++){
        int k = read();

        if(k == 1){
            data[++cnt].type = 1;
            data[cnt].x = read();
            data[cnt].y = read();
            data[cnt].val = read();
        }else{
            int a = read();
            int b = read();
            int c = read();
            int d = read();

            data[++cnt] = (Statement){2, c, d, 1, ++query_cnt};
            data[++cnt] = (Statement){2, a-1, b-1, 1, query_cnt};
            data[++cnt] = (Statement){2, c, b-1, -1, query_cnt};
            data[++cnt] = (Statement){2, a-1, d, -1, query_cnt};
        }
    }

    cdq(1, cnt);

    for(register int i=1; i<=query_cnt; i++){
        printf("%d\n", ans[i]);
    }


    return 0;
}