[模板] 树链剖分

树链剖分就是将树分割成多条链,然后利用数据结构(线段树、树状数组等)来维护这些链。

树链剖分可以用来解决两点间路径相关的查询,修改问题。

树链剖分基本概念

重结点:子树结点数目最多的结点
轻节点:父亲节点中除了重结点以外的结点
重边:父亲结点和重结点连成的边
轻边:父亲节点和轻节点连成的边
重链:由多条重边连接而成的路径
轻链:由多条轻边连接而成的路径

比如上面这幅图中,用黑线连接的结点都是重结点,其余均是轻结点,2-11、1-14就是重链,其他就是轻链,用红点标记的就是该结点所在链的起点。

 

定义一些全局数组

son[] : 记录重儿子信息
siz[] : 记录子树个数
top[] : 记录所在链的链顶
deep[] : 记录节点深度
fa[] : 记录节点父亲
id[] : 记录dfs序

 

Step 1 : 进行第一次dfs

第一次dfs的目的是处理出每一个节点的重儿子,深度,父亲,子树大小

void dfs1(int u, int f){

    deep[u] = deep[f] + 1; //计算每个节点的深度 
    siz[u] = 1; //计算每个节点子树大小 
    fa[u] = f; //计算每个节点的父亲 

    for(int i=0; i<map[u].size(); i++){
        int v = map[u][i];
        if(v != f){
            dfs1(v,u); //递归处理子节点
            siz[u] += siz[v];

            if(siz[v] > siz[son[u]]) //如果这个儿子子树节点数目更多,更新重儿子
                son[u] = v;
        }
    }
}

Step 2: 进行第二次dfs

第二次dfs的目的是处理出链顶, dfs序, (同时也可以处理出线段树build所需要的数组)

void dfs2(int u,int topf){
    id[u] = ++time_stamp; //处理dfs序
    top[u] = topf; //处理链顶

    if(son[u]){
        dfs2(son[u],topf); //重儿子,重链链顶延续

        for(int i=0; i<map[u].size(); i++){
            int v = map[u][i];
            if(v != son[u] && v!=fa[u])
                dfs2(v,v); //重链链顶链顶重新设置

        }
    }
}

Step 3: 线段树操作

这里的线段树以 BZOJ 1036 为例

void pushup(int o){
    int lc = o << 1;
    int rc = o << 1 | 1;

    num[o] = max(num[lc],num[rc]);
    sum[o] = sum[lc] + sum[rc];
}

void update(int o, int l, int r, int p, int k){
    int lc = o << 1, rc = o << 1 | 1, mid = (l+r) >>1;

    if(l == r){
        sum[o] = num[o] = k;
        return;
    }

    if(p <= mid)
        update(lc, l, mid, p, k);
    if(mid < p)
        update(rc, mid+1, r, p, k);

    pushup(o);
}

int getMax(int o, int l, int r, int ql, int qr){
    int lc = o << 1, rc = o <<1 | 1, mid=(l+r)>>1, ans=-0x3f3f3f3f;

    if(ql <= l && r <= qr)
        return num[o];

    if(ql <= mid)
        ans = max(ans,getMax(lc, l, mid, ql, qr));
    if(mid < qr)
        ans = max(ans,getMax(rc, mid+1, r, ql, qr));

    return ans;
}

int getSum(int o, int l, int r, int ql, int qr){
    int lc = o << 1, rc = o <<1 | 1, mid = (l+r) >> 1, ans = 0;

    if(ql <=l && r <= qr)
        return sum[o];

    if(ql <= mid)
        ans += getSum(lc, l, mid, ql, qr);

    if(mid < qr)
        ans += getSum(rc, mid+1, r, ql, qr);

    return ans;
}

Step 4: 树链剖分操作

查询操作,非常类似于倍增求LCA,不过这里直接跳转到top的父亲节点,(但是轻链的top就是自己)。需要注意的是,每次循环只能跳一次,并且让top结点深的那个来跳到top的位置,避免两个一起跳从而错过。

本题只涉及到了路径查询,对于路径修改,和查询操作非常类似,只是将向线段树查询链信息改为向线段树修改链信息。

int findMax(int u, int v){
    int f1 = top[u], f2= top[v];
    int ans = -0x3f3f3f3f;

    while(f1 != f2){
        if(deep[f1] < deep[f2]){
            swap(f1,f2);
            swap(u,v);
        }

        ans = max(ans, getMax(1, 1, n, id[f1], id[u]));

        u = fa[f1];
        f1 = top[u];
    }

    if(deep[u] > deep[v])
        swap(u,v);

    ans = max(ans, getMax(1, 1, n, id[u], id[v]));

    return ans;
}

int findSum(int u, int v){
    int f1 = top[u], f2=top[v];

    int ans = 0;

    while(f1 != f2){
        if(deep[f1] < deep[f2]){
            swap(u,v);
            swap(f1,f2);
        }

        ans += getSum(1, 1, n, id[f1], id[u]);

        u = fa[f1];
        f1 = top[u];
    }

    if(deep[u] > deep[v])
        swap(u,v);

    ans += getSum(1, 1, n, id[u], id[v]);

    return ans;
}

Step 5:主函数

因为这里涉及到了一些函数的特殊调用,特将主函数也一同附上

int main(){ 
    scanf("%d",&n);

    for(int i=0; i<n-1; i++){
        int u,v;
        scanf("%d%d",&u,&v);
        map[u].push_back(v);
        map[v].push_back(u);
    }

    deep[1] = 1;
    dfs1(1,0);
    dfs2(1,1);

    for(int i=1; i<=n; i++){
        scanf("%d",&w[i]);
        update(1, 1, n, id[i], w[i]);
    }

    scanf("%d",&q);

    for(int i=1; i<=q; i++){
        char ch[10];
        scanf("%s",ch);

        if(ch[0] == 'C'){
            int u,k;
            scanf("%d%d",&u,&k);
            update(1, 1, n, id[u], k);
        }else if(ch[1] == 'S'){
            int u,v;
            scanf("%d%d",&u,&v);
            printf("%d\n",findSum(u,v));
        }else if(ch[1] == 'M'){
            int u,v;
            scanf("%d%d",&u,&v);
            printf("%d\n",findMax(u,v));
        }
    }
}