[题解] Luogu – 3398 仓鼠找sugar

题目描述

小仓鼠的和他的基(mei)友(zi)sugar 住在地下洞穴中,每个节点的编号为 $1 – n$ 。地下洞穴是一个树形结构。这一天小仓鼠打算从从他的卧室 $a$ 到餐厅 $b$,而他的基友同时要从他的卧室 $c$ 到图书馆 $d$ 。他们都会走最短路径。现在小仓鼠希望知道,有没有可能在某个地方,可以碰到他的基友?

小仓鼠那么弱,还要天天被 $zzq$ 大爷虐,请你快来救救他吧!

输入格式

第一行两个整数 $n$ 和 $q$,表示这棵树节点的个数和询问的个数
接下来 $n-1$ 行,每行两个正整数 $u$ 和 $v$ ,表示节点 $u$ 到节点 $v$ 之间有一条边。
接下来 $q$ 行,每行四个正整数 $a$、$b$ 、$c$ 和 $d$,表示节点编号,也就是一次询问,其意义如上。

输出格式

对于每个询问,如果有公共点,输出大写字母 $Y$,否则输出 $N$

样例输入

5 5
2 5
4 2
1 3
1 4
5 1 5 1
2 2 1 4
4 1 3 4
3 1 1 5
3 5 1 4

样例输出

Y
N
Y
Y
Y

解题思路

不难发现,如果两条路径相交,那么一条路径的 $LCA$ 一定在另一条路径上

如何判断一个点 $x$ 在另一条路径 $a-b$ 上,令 $a-b$ 的 $LCA$ 为 $c$,那么

$deep_x \geq deep_c $
$lca(a, x) = x$ 或 $lca(b,x) = x$

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = 100000 + 10;
const int MAXM = 200000 + 10;

int Head[MAXN], to[MAXM], Next[MAXM], tot = 1;

inline void add(int a, int b){
    to[tot] = b;
    Next[tot] = Head[a];
    Head[a] = tot++;
}

inline int read(){
    int x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9')
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return x;
}

int f[MAXN][20], deep[MAXN];

inline void dfs(int x, int fa){
    for(register int i=1; i<=17; i++){
        f[x][i] = f[f[x][i-1]][i-1];
    }   

    for(register int i=Head[x]; i; i=Next[i]){
        int v = to[i];
        if(v != fa){
            f[v][0] = x;
            deep[v] = deep[x] + 1;

            dfs(v, x);
        }       
    }
}

inline int LCA(int a, int b){
    if(deep[a] < deep[b]){
        swap(a, b);
    }

    for(register int i=17; i>=0; i--){
        if(deep[f[a][i]] >= deep[b]){
            a = f[a][i];
        }
    }

    if(a == b)
        return a;

    for(register int i=17; i>=0; i--){
        if(f[a][i] != f[b][i]){
            a = f[a][i];
            b = f[b][i];
        }
    }

    return f[a][0];
}

int n, q;

int main(){ 
    n = read();
    q = read();

    for(register int i=0; i<n-1; i++){
        int a = read();
        int b = read();

        add(a, b);
        add(b, a);
    }

    deep[1] = 1;
    dfs(1, 0);

    for(register int i=0; i<q; i++){
        int a = read();
        int b = read();
        int c = read();
        int d = read();

        int x = LCA(a, b);
        int y = LCA(c, d);

        if(deep[x] < deep[y]){
            swap(x, y);
            swap(a, c);
            swap(b, d);
        }

        if(LCA(x, c) == x || LCA(x, d) == x){
            printf("Y\n");
        }else{
            printf("N\n");
        }
    }
    return 0;
}