[题解] JZOJ – 6050 树上四次求和

题目背景

对于一棵树,我们定义 $dis(i, j)$ 为节点 $i$ 和 $j$ 之间最短路径上的边数。

对于一个长度为 $n$ 的序列 $a$,我们定义 $w(l, r)$ 为 $\sum_{i=l}^r \sum_{j=i}^r dis(a_i, a_j)$

题目描述

给你一棵 $n$ 个点的树以及一个 $1-n$ 的排列 $a$,有 $q$ 次询问,每次给出 $k$,求 $\sum_{i=1}^k \sum_{j=i}^k w(i, j)$ 对 $998244353$ 取模的值

输入格式

第一行两个正整数 $n$ 和 $q$
接下来 $n-1$ 行,每行两个正整数 $u$ 和 $v$,表示 $u$ 和 $v$ 之间有一条树边
接下来一行 $n$ 个数字描述排列 $a$
接下来 $q$ 行每行一个正整数 $k_i$ 表示询问

输出格式

输出 $q$ 行,第 $i$ 行表示询问 $k_i$ 的答案

样例输入

4 4
1 2
2 3
2 4
3 2 1 4
1
2
3
4

样例输出

0
1
6
21

数据范围及提示

对于前 $30 \%$ 的数据,$n \leq 1000$
对于另 $20 \%$ 的数据,$q_i = n$
对于另 $20 \%$ 的数据,第 $i$ 条边连接点 $i$ 与点 $i+1$
对于 $100 \%$ 的数据,$n, q \leq 10^5, u, v, k_i \leq n$

解题思路

题目要求

$$
\begin{aligned}
f_k & = \sum_{l=1}^k \sum_{r=l}^k w(l, r) \\
& = \sum_{l=1}^k \sum_{r=l}^k \sum_{i=l}^r \sum_{j=i}^r dis(a_i, a_j)
\end{aligned}
$$

考虑 $k$ 从 $k-1$ 转移到 $k$ 时答案的增加量

$$
\begin{aligned}
f_k – f_{k-1} & = \sum_{i=1}^k w(i, k) \\
& = \sum_{i=1}^k \left[ w(i, k-1) + \sum_{j=i}^k dis(a_j,a_k) \right] \\
&= \sum_{i=1}^k w(i, k-1) + \sum_{i=1}^k \sum_{j=i}^k dis(a_j, a_k)
\end{aligned}
$$

因为

$$
w(k, k-1) = 0
$$

所以式子的前半部分就是上一次从 $k-2$ 转移到 $k-1$ 的增加量

考虑将和式展开,发现对于每一个 $dis(a_j, a_k)$ , 都计算了 $j$ 次

$$
\sum_{i=1}^k \sum_{j=i}^k dis(a_j, a_k) = \sum_{i=1}^k i \times dis(a_i, a_k)
$$

这样的算法是 $O(n^2)$ 的, 再考虑将 $dis(a_i, a_k)$ 展开

$$
dis(a_i, a_k) = deep_{a_i} + deep_{a_k} – 2 \times deep_{lca(a_i, a_k)}
$$

那么式子就被划分成了三部分

$$
\sum_{i=1}^k i \times deep_{a_i}
$$

可以通过预处理前缀和,$O(n)$ 预处理,$O(1)$ 查询

$$
\sum_{i=1}^k i \times deep_{a_k}
$$

一共出现了 $\frac{k \times (k-1)}{2}$ 次,可以 $O(1)$ 计算

问题转化为如何快速求解下面这个式子

$$
\sum_{i=1}^k i \times deep_{lca(a_i, a_k)}
$$

它可以通过树链剖分在 $O(\log^2 n)$ 的时间内求解,也可以使用 $\mathrm{lct}$ 在 $O(\log n)$ 的时间内求解

对于每一个 $a_i$,它与 $a_x$ 的 $\mathrm{lca}$ 一定在它到根节点的路径上,因此将 $a_i$ 到根节点上的每一个节点权值加上 $i$,统计 $a_k$ 到根节点的权值和即可

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const long long MAXN = 200000 + 10;
const long long MAXM = 400000 + 10;
const long long MOD = 998244353;

long long Head[MAXN], to[MAXM], Next[MAXM], tot = 1;
long long sum[MAXN << 2], add[MAXN << 2];

inline void _add(long long a, long long b){
    to[tot] = b;
    Next[tot] = Head[a];
    Head[a] = tot++;
}

inline void pushdown(long long root, long long left, long long right){
    if(add[root]){
        long long lson = root << 1, rson = root << 1 | 1, mid = (left + right) >> 1;

        add[lson] = (add[lson] + add[root]) % MOD;
        add[rson] = (add[rson] + add[root]) % MOD;

        sum[lson] = (sum[lson] + (mid - left + 1) * add[root]) % MOD;
        sum[rson] = (sum[rson] + (right - mid) * add[root]) % MOD;

        add[root] = 0;
    }
}

inline void update(long long root, long long left, long long right, long long qleft, long long qright, long long k){
    long long lson = root << 1, rson = root << 1 | 1, mid = (left + right) >> 1;

    if(qleft <= left && right <= qright){
        add[root] = (add[root] + k) % MOD;
        sum[root] = (sum[root] + (right - left + 1) * k) % MOD;
        return;
    }

    pushdown(root, left, right);

    if(qleft <= mid)
        update(lson, left, mid, qleft, qright, k);
    if(mid < qright)
        update(rson, mid+1, right, qleft, qright, k);

    sum[root] = (sum[lson] + sum[rson]) % MOD;
}

inline long long query(long long root, long long left, long long right, long long qleft, long long qright){
    long long lson = root << 1, rson = root << 1 | 1, mid = (left + right) >> 1;

    if(qleft <= left && right <= qright){
        return sum[root];
    }

    pushdown(root, left, right);

    long long ans = 0;

    if(qleft <= mid)
        ans += query(lson, left, mid, qleft, qright);
    if(mid < qright)
        ans += query(rson, mid+1, right, qleft, qright);

    return ans % MOD;
}

long long n, q;
long long deep[MAXN], topf[MAXN], son[MAXN], fa[MAXN], siz[MAXN];
long long id[MAXN], data[MAXN], top[MAXN], time_stamp;

inline void dfs1(long long x, long long f){
    siz[x] = 1;
    deep[x] = deep[f] + 1;
    fa[x] = f;

    for(register long long i=Head[x]; i; i=Next[i]){
        long long v = to[i];

        if(v != f){
            dfs1(v, x);

            siz[x] += siz[v];

            if(siz[v] > siz[son[x]])
                son[x] = v;
        }
    }
}

inline void dfs2(long long x, long long topf){
    id[x] = ++time_stamp;
    data[time_stamp] = 0;

    top[x] = topf;

    if(son[x]){
        dfs2(son[x], topf);

        for(register long long i=Head[x]; i; i=Next[i]){
            long long v = to[i];

            if(v != son[x] && v != fa[x]){
                dfs2(v, v);
            }
        }
    }
}

inline long long treeQuery(long long x, long long y){
    long long ans = 0;

    long long f1 = top[x];
    long long f2 = top[y];

    while(f1 != f2){
        if(deep[f1] < deep[f2]){
            swap(x, y);
            swap(f1, f2);
        }

        ans = (ans + query(1, 1, n, id[f1], id[x])) % MOD;

        x = fa[f1];
        f1 = top[x];
    }

    if(deep[x] > deep[y])
        swap(x, y);

    ans = (ans + query(1, 1, n, id[x], id[y])) % MOD;

    return ans;
} 


inline long long treeUpdate(long long x, long long y, long long k){
    long long ans = 0;

    long long f1 = top[x];
    long long f2 = top[y];

    while(f1 != f2){
        if(deep[f1] < deep[f2]){
            swap(x, y);
            swap(f1, f2);
        }

        update(1, 1, n, id[f1], id[x], k);

        x = fa[f1];
        f1 = top[x];
    }

    if(deep[x] > deep[y])
        swap(x, y);

    update(1, 1, n, id[x], id[y], k);
    return ans;
} 

inline long long read(){
    long long x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9')
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return x;
}

long long a[MAXN];
long long result[MAXN];
long long prefix[MAXN];

signed main(){
    n = read();
    q = read();

    for(register long long i=1; i<n; i++){
        long long a = read();
        long long b = read();

        _add(a, b);
        _add(b, a);
    }

    dfs1(1, 0);
    dfs2(1, 1);

    for(register long long i=1; i<=n; i++)
        a[i] = read();

    long long lastans = 0;

    for(register long long i=1; i<=n; i++){
        prefix[i] = (prefix[i-1] + (i * deep[a[i]] % MOD)) % MOD;
    }

    for(register long long i=1; i<=n; i++){
        long long valadd = ((lastans + (((i * (i-1)) / 2) % MOD * deep[a[i]]) % MOD) % MOD + prefix[i-1]) % MOD;
        valadd -= 2 * treeQuery(1, a[i]);
        valadd = (valadd % MOD + MOD) % MOD;

        lastans = valadd;
        result[i] = (result[i-1] + valadd) % MOD;
        treeUpdate(1, a[i], i);
    }

    for(register long long i=1; i<=q; i++){
        long long k = read();
        printf("%lld\n", result[k]);
    }
}