[题解] JZOJ – 6050 树上四次求和

样例输入

4 4
1 2
2 3
2 4
3 2 1 4
1
2
3
4


样例输出

0
1
6
21


解题思路

其中 $w(l, r) = \sum_{i=l}^r \sum_{j=i}^r dis(a_i, a_j)$，则

$$\begin{aligned} f_k & = \sum_{l=1}^k \sum_{r=l}^k w(l, r) \\ & = \sum_{l=1}^k \sum_{r=l}^k \sum_{i=l}^r \sum_{j=i}^r dis(a_i, a_j) \end{aligned}$$

$$\begin{aligned} f_k - f_{k-1} & = \sum_{i=1}^k w(i, k) \\ & = \sum_{i=1}^k \left[ w(i, k-1) + \sum_{j=i}^k dis(a_j,a_k) \right] \\ &= \sum_{i=1}^k w(i, k-1) + \sum_{i=1}^k \sum_{j=i}^k dis(a_j, a_k) \end{aligned}$$

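由于 $w(k, k-1) = 0$（空区间），有

$$\sum_{i=1}^k w(i, k-1) = \sum_{i=1}^{k-1} w(i, k-1) = f_{k-1} - f_{k-2}$$

即 $f_k - f_{k-1}$ 可由上一步的差 $f_{k-1} - f_{k-2}$ 递推得到，只需再求出第二项。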

$$\sum_{i=1}^k \sum_{j=i}^k dis(a_j, a_k) = \sum_{i=1}^k i \times dis(a_i, a_k)$$

（交换求和顺序：每个 $j$ 在 $i = 1, \dots, j$ 时各被计入一次，共计 $j$ 次。）

$$dis(a_i, a_k) = deep_{a_i} + deep_{a_k} - 2 \times deep_{lca(a_i, a_k)}$$

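代入上式得

$$\sum_{i=1}^k i \times dis(a_i, a_k) = \sum_{i=1}^k i \times deep_{a_i} + \sum_{i=1}^k i \times deep_{a_k} - 2 \sum_{i=1}^k i \times deep_{lca(a_i, a_k)}$$

分别计算三项：

- $\sum_{i=1}^k i \times deep_{a_i}$：维护前缀和即可；
- $\sum_{i=1}^k i \times deep_{a_k} = \frac{k(k+1)}{2} \times deep_{a_k}$；
- $\sum_{i=1}^k i \times deep_{lca(a_i, a_k)}$：对每个 $i$，把根到 $a_i$ 路径上的每个点加上 $i$，则根到 $a_k$ 路径上的点权和恰好等于该式（根的深度记为 $1$），用树链剖分 + 线段树维护路径加与路径求和。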

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
// Array capacities (with a little slack) and the answer modulus.
constexpr long long MAXN = 200000 + 10;   // size of per-vertex arrays
constexpr long long MAXM = 400000 + 10;   // size of edge arrays (to / Next)
constexpr long long MOD = 998244353;      // all sums are reduced modulo this prime

// Adjacency list: lists are walked as `for(i = Head[u]; i; i = Next[i])`,
// so edge index 0 acts as the end-of-list marker.
long long Head[MAXN];      // first edge index of each vertex (0 = no edges)
long long to[MAXM];        // endpoint of each stored edge
long long Next[MAXM];      // following edge in the same vertex's list
long long tot = 1;         // edge-slot counter written by _add; starts at 1 since 0 ends a list

// Segment tree over DFS positions (node i has children 2i and 2i+1).
long long sum[MAXN << 2];  // node sums, kept modulo MOD
long long add[MAXN << 2];  // lazy-add tags read by pushdown

inline void _add(long long a, long long b){
to[tot] = b;
}

/// Pushes the pending lazy addition of node `root` (covering [left, right])
/// down to its two children and clears it, so it is applied exactly once.
inline void pushdown(long long root, long long left, long long right){
    if(!add[root])
        return;  // nothing pending on this node

    long long lson = root << 1, rson = root << 1 | 1, mid = (left + right) >> 1;

    // Children inherit the tag so their own descendants can be fixed later.
    add[lson] = (add[lson] + add[root]) % MOD;
    add[rson] = (add[rson] + add[root]) % MOD;

    // Each child's sum grows by (its length) * tag.
    sum[lson] = (sum[lson] + (mid - left + 1) * add[root]) % MOD;
    sum[rson] = (sum[rson] + (right - mid) * add[root]) % MOD;

    add[root] = 0;  // tag consumed; must not be pushed a second time
}

/// Adds k to every position of [qleft, qright] inside this node's range
/// [left, right]. Sums are kept modulo MOD.
inline void update(long long root, long long left, long long right, long long qleft, long long qright, long long k){
    long long lson = root << 1, rson = root << 1 | 1, mid = (left + right) >> 1;

    if(qleft <= left && right <= qright){
        sum[root] = (sum[root] + (right - left + 1) * k) % MOD;
        // Remember k so the children receive it when this node is split later.
        add[root] = (add[root] + k) % MOD;
        return;
    }

    pushdown(root, left, right);

    if(qleft <= mid)
        update(lson, left, mid, qleft, qright, k);
    if(mid < qright)
        update(rson, mid+1, right, qleft, qright, k);

    sum[root] = (sum[lson] + sum[rson]) % MOD;
}

/// Returns the sum (mod MOD) of positions [qleft, qright] that fall inside
/// the node `root`, which covers [left, right].
inline long long query(long long root, long long left, long long right, long long qleft, long long qright){
    // Node fully inside the query range: its stored sum is the answer.
    if(qleft <= left && right <= qright)
        return sum[root];

    pushdown(root, left, right);

    const long long mid = (left + right) >> 1;
    long long total = 0;
    if(qleft <= mid)
        total += query(root << 1, left, mid, qleft, qright);
    if(qright > mid)
        total += query(root << 1 | 1, mid + 1, right, qleft, qright);

    return total % MOD;
}

long long n;               // vertex count; also the segment-tree range [1, n]
long long q;
// Heavy-light decomposition state (filled by dfs1 / dfs2).
long long deep[MAXN];      // depth; root called with parent 0 gets depth 1
long long topf[MAXN];      // not referenced in the visible code
long long son[MAXN];       // heavy child (0 = leaf)
long long fa[MAXN];        // parent
long long siz[MAXN];       // subtree size
long long id[MAXN];        // DFS position of each vertex
// NOTE: under `using namespace std` in C++17 an unqualified `data` is ambiguous
// with std::data; refer to this array as `::data`.
long long data[MAXN];
long long top[MAXN];       // top vertex of each vertex's heavy chain
long long time_stamp;      // DFS position counter

/// First heavy-light pass over the subtree of x whose parent is f.
/// Fills fa, deep (deep[x] = deep[f] + 1), siz (subtree size) and son
/// (child with the largest subtree, 0 for a leaf).
/// Recursive: the recursion depth equals the tree height.
inline void dfs1(long long x, long long f){
    siz[x] = 1;
    deep[x] = deep[f] + 1;
    fa[x] = f;

    // `register` was dropped: the storage class was removed in C++17.
    for(long long i = Head[x]; i; i = Next[i]){
        long long v = to[i];
        if(v == f)
            continue;   // edge back to the parent

        dfs1(v, x);
        siz[x] += siz[v];

        // son[x] starts as 0; assuming vertex 0 is never visited, siz[0] stays 0
        // and the first child always replaces it.
        if(siz[v] > siz[son[x]])
            son[x] = v;
    }
}

inline void dfs2(long long x, long long topf){
id[x] = ++time_stamp;
data[time_stamp] = 0;

top[x] = topf;

if(son[x]){
dfs2(son[x], topf);

for(register long long i=Head[x]; i; i=Next[i]){
long long v = to[i];

if(v != son[x] && v != fa[x]){
dfs2(v, v);
}
}
}
}

inline long long treeQuery(long long x, long long y){
long long ans = 0;

long long f1 = top[x];
long long f2 = top[y];

while(f1 != f2){
if(deep[f1] < deep[f2]){
swap(x, y);
swap(f1, f2);
}

ans = (ans + query(1, 1, n, id[f1], id[x])) % MOD;

x = fa[f1];
f1 = top[x];
}

if(deep[x] > deep[y])
swap(x, y);

ans = (ans + query(1, 1, n, id[x], id[y])) % MOD;

return ans;
}

inline long long treeUpdate(long long x, long long y, long long k){
long long ans = 0;

long long f1 = top[x];
long long f2 = top[y];

while(f1 != f2){
if(deep[f1] < deep[f2]){
swap(x, y);
swap(f1, f2);
}

update(1, 1, n, id[f1], id[x], k);

x = fa[f1];
f1 = top[x];
}

if(deep[x] > deep[y])
swap(x, y);

update(1, 1, n, id[x], id[y], k);
return ans;
}

/// Reads the next non-negative decimal integer from stdin, skipping any
/// non-digit characters before it. Returns 0 if EOF is reached before a digit.
inline long long read(){
    long long x = 0;
    int ch = getchar();   // int, not char, so EOF stays distinguishable from data

    // Without the EOF check this loop would spin forever at end of input.
    while(ch != EOF && (ch < '0' || ch > '9'))
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return x;
}

long long a[MAXN];       // input sequence a_1 .. a_n
long long result[MAXN];  // not referenced in the visible code
long long prefix[MAXN];  // prefix[i] = sum_{j<=i} j * deep[a_j] (mod MOD), filled in main

signed main(){

for(register long long i=1; i<n; i++){

}

dfs1(1, 0);
dfs2(1, 1);

for(register long long i=1; i<=n; i++)

long long lastans = 0;

for(register long long i=1; i<=n; i++){
prefix[i] = (prefix[i-1] + (i * deep[a[i]] % MOD)) % MOD;
}

for(register long long i=1; i<=n; i++){
long long valadd = ((lastans + (((i * (i-1)) / 2) % MOD * deep[a[i]]) % MOD) % MOD + prefix[i-1]) % MOD;
valadd -= 2 * treeQuery(1, a[i]);