[题解] ZJOI2019 线段树

题目描述

九条可怜是一个喜欢数据结构的女孩子,在常见的数据结构中,可怜最喜欢的就是线段树。

线段树的核心是懒标记,下面是一个带懒标记的线段树的伪代码,其中 tag 数组为懒标记:

41b9e2f9cde7888ea48c35cd1a767356.png

其中函数 Lson(Node) 表示 Node 的左儿子,Rson(Node) 表示 Node 的右儿子。

现在可怜手上有一棵 [1,n] 上的线段树,编号为 1 。这棵线段树上的所有节点的 tag 均为 0 。接下来可怜进行了 m 次操作,操作有两种:

  • 1 l r ,假设可怜当前手上有 t 棵线段树,可怜会把每棵线段树复制两份(tag 数组也一起复制),原先编号为 i 的线段树复制得到的两棵编号为 2i12i ,在复制结束后,可怜手上一共有 2t 棵线段树。接着,可怜会对所有编号为奇数的线段树进行一次 Modify(root,1,n,l,r)
  • 2 ,可怜定义一棵线段树的权值为它上面有多少个节点 tag1。可怜想要知道她手上所有线段树的权值和是多少。

输入格式

第一行输入两个整数 n,m表示初始区间长度和操作个数。

接下来 m 行每行描述一个操作,输入保证 1lrn

输出格式

对于每次询问,输出一行一个整数表示答案,答案可能很大,对 998244353 取模后输出即可。

样例输入

5 5
2
1 1 3
2
1 3 5
2
C++

样例输出

0
1
6
C++

解题思路

一道可写的简单线段树题

经过 k1 操作后,将有 2k 棵线段树

将线段树合起来看,将线段树中一个节点编号为 itag1 的概率记为 fi,那么答案就为 ifi×2k

考虑 update 操作对答案的影响,我们先回顾这道题 update 的写法

void update(int root, int left, int right, int qleft, int qright){
    int lson = root << 1, rson = root << 1 | 1, mid = (left + right) >> 1;
    if(qleft <= left && right <= qright){
        tag[root] = 1;
        return;
    }

    pushdown(root, left, right);

    if(qleft <= mid)
        update(lson, left, mid, qleft, qright);
    if(mid < qright)
        update(rson, mid + 1, right, qleft, qright);
}
C++

因为有 pushdown 操作,我们考虑用 gi 表示节点编号为 i 和父亲节点 fai 中至少有一个 tag1 的概率

情况一:假如这个区间被完全覆盖

因为每一次修改只修改一半的线段树,没有被修改的对答案的贡献是 fi2gi2 ,被修改的对答案的贡献是 1212
fifi+12gigi+12
情况二:假如进行了 pushdown 操作

被修改的线段树中,进行了 pushdown 的线段树区间 tag 一定是 0,并且从根到这个区间路径上所有节点的 tag 都应该是 0

所以这一半线段树对答案的贡献是 00
fifi2gigi2
情况三:假如进行了 pushdown,但子区间不符合递归 update 的条件

那么显然只对左儿子区间或右儿子区间的 f 造成了影响,对答案的贡献是 fson+gson2
fsonfson+gson2
gson 是不变的

情况四:由于 pushdown 或赋值操作,父区间 tag 被赋值为 1

这对区间的 fi 本身是没有影响的,影响只是 gi,这一部分对答案的贡献是 gi+12
gigi+12
但是这种情况出现次数可能会很多,比如出现一次情况一,就会对其所有子树有影响

但我们发现每个子区间的权值都是乘上 12 后加上 12

因此可以线段树维护 mul 标记和 add 标记解决

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <assert.h>

using namespace std;

const int MAXN = 100000 + 10;
const int MOD = 998244353;
const int INV = 499122177;

inline int read(){
    int x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9')
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x * 10 + ch - '0';
        ch = getchar();
    }

    return x;
}

int n, m;
long long siz = 1;

namespace SEG{
    long long sum[MAXN << 3], f[MAXN << 3], g[MAXN << 3], add[MAXN << 3], mul[MAXN << 3];

    inline void init(){
        fill(mul, mul + (n << 2) + 1, 1);
    }

    inline void pushup(int root){
        int lson = root << 1, rson = root << 1 | 1;
        sum[root] = (sum[lson] + sum[rson] + f[root]) % MOD;
    }

    inline void pushdown(int root, int left, int right){
        int lson = root << 1, rson = root << 1 | 1;

        if(mul[root] != 1 || add[root] != 0){
            mul[lson] = (mul[root] * mul[lson]) % MOD;
            mul[rson] = (mul[root] * mul[rson]) % MOD;

            add[lson] = (mul[root] * add[lson] + add[root]) % MOD;
            add[rson] = (mul[root] * add[rson] + add[root]) % MOD;

            g[lson] = (mul[root] * g[lson] + add[root]) % MOD;
            g[rson] = (mul[root] * g[rson] + add[root]) % MOD;

            mul[root] = 1;
            add[root] = 0;
        }
    }

    inline void update(int root, int left, int right, int qleft, int qright){
        int lson = root << 1, rson = root << 1 | 1, mid = (left + right) >> 1;

        if(qleft <= left && right <= qright){
            f[root] = (f[root] * INV + INV) % MOD;
            g[root] = (g[root] * INV + INV) % MOD;

            mul[root] = (mul[root] * INV) % MOD;
            add[root] = (add[root] * INV + INV) % MOD;

            pushup(root);
            return;
        }

        pushdown(root, left, right);

        f[root] = (f[root] * INV) % MOD;
        g[root] = (g[root] * INV) % MOD;

        if(qleft <= mid){
            update(lson, left, mid, qleft, qright);
        }else{
            f[lson] = (f[lson] * INV + g[lson] * INV) % MOD;
            pushup(lson);
        }

        if(mid < qright){
            update(rson, mid + 1, right, qleft, qright);
        }else{
            f[rson] = (f[rson] * INV + g[rson] * INV) % MOD;
            pushup(rson);
        }

        pushup(root);
    }

}

int main(){
    freopen("segment.in", "r", stdin);
    freopen("segment.out", "w", stdout);

    n = read();
    m = read();

    for(register int i=1; i<=m; i++){
        int op = read();

        if(op == 2){
            printf("%lld\n", SEG::sum[1] * siz % MOD);
        }else{
            int l = read();
            int r = read();

            SEG::update(1, 1, n, l, r);

            siz = (siz << 1) % MOD;
        }
    }    
    return 0;
}

C++
评论 在线讨论
      加载更多