链接:https://ac.nowcoder.com/acm/contest/3/B
来源:牛客网

Borrow Classroom
时间限制:C/C++ 3秒,其他语言6秒
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld
题目描述
每年的BNU校赛都会有两次赛前培训,为此就需要去借教室,由于SK同学忙于出题,这个事情就由小Q同学来跑腿。SK同学准备从宿舍出发,把借教室的单子交给小Q同学让他拿去教务处盖章,但是何老师突然发现SK同学好像借错教室了,想抢在借教室的单子被送到教务处之前拦截下来。

现在把校园抽象成一棵n个节点的树,每条边的长度都是一个单位长度,从1到n编号,其中教务处位于1号节点,接下来有q个询问,每次询问中SK同学会从B号节点出发,到C号节点找到小Q同学并将借教室的单子交给他,然后小Q同学再从C号节点出发前往教务处,何老师会从A号节点出发开始拦截。

所有人在一个单位时间内最多走一个单位距离,只要何老师在单子还没被送到教务处之前遇到拿着单子的同学都算拦截成功,如果小Q同学已经到了教务处,那么由于小Q同学手速极快,单子会被立即交上去,即使何老师到了教务处也无济于事,你需要判断何老师是否能够拦截成功。
输入描述:
第一行是一个正整数T(≤ 5),表示测试数据的组数, 对于每组测试数据, 第一行是两个整数n,q(1≤ n,q ≤ 100000),分别表示节点数和询问数, 接下来n-1行,每行包含两个整数x,y(1≤ x,y ≤ n),表示x和y之间有一条边相连,保证这些边能构成一棵树, 接下来q行,每行包含三个整数A,B,C(1 ≤ A,B,C ≤ n),表示一个询问,其中A是何老师所在位置,B是SK同学所在位置,C是小Q同学所在位置,保证小Q同学初始不在教务处。
输出描述:
对于每个询问,输出一行,如果何老师能成功拦截则输出"YES"(不含引号),否则输出"NO"(不含引号)。
示例1
输入
复制
1
7 2
1 2
2 3
3 4
4 7
1 5
1 6
3 5 6
7 5 6
输出
复制
YES
NO

题意:

思路:
我们把情况分为两类来讨论,A 和 C 最近公共祖先 是1 和不是1 的情况,

① 当A 节点和 C 节点的 LCA 是1 ,则表明,在A通向1节点的路径和 C 通向1节点的路径没有相交点,那么 A 想成功拦截就必须 A 到1 节点的距离 小于 B->C -> 1 的距离,
② LCA 不是1的时候,那么在通往1节点的路径上是有交点的,且交点在 到达1 之前,那么 我们只需要 A 到1 节点的距离 小于等于 B->C -> 1 的距离 即可 成功拦截。

求两节点距离我用的是倍增LCA 法 。

细节见代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#define ALL(x) (x).begin(), (x).end()
#define rt return
#define sz(a) int(a.size())
#define all(a) a.begin(), a.end()
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
#define pii pair<int,int>
#define pll pair<long long ,long long>
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
#define gg(x) getInt(&x)
#define db(x) cout<<"== [ "<<x<<" ] =="<<endl;
using namespace std;
typedef long long ll;
ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}
ll lcm(ll a,ll b){return a/gcd(a,b)*b;}
ll powmod(ll a,ll b,ll MOD){ll ans=1;while(b){if(b%2)ans=ans*a%MOD;a=a*a%MOD;b/=2;}return ans;}
inline void getInt(int* p);
const int maxn=1000010;
const int inf=0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/
std::vector<int> son[maxn];
int depth[maxn];
int far[maxn][22];
int n,q;
void dfs(int x,int pre)
{
//    cout<<x<<" "<<pre<<endl;
	depth[x]=depth[pre]+1;
	far[x][0]=pre;
	for(int i=1;i<20;i++)
	{
		far[x][i]=far[far[x][i-1]][i-1];
	}
//    for(int i=1;i<20;i++)
//    {
//        fa[rt][i]=fa[fa[rt][i-1]][i-1];
//    }
	for(auto y:son[x])
	{
		if(y!=pre)
		{
			dfs(y,x);
		}
	}
}
int lca(int x,int y)
{
	if(depth[x]<depth[y])
	{
		swap(x,y);
	}
	for(int i=19;i>=0;i--)
	{
		if(depth[x]-(1<<i)>=(depth[y]))
		{
			x=far[x][i];
		}
	}
	if(x==y)
	{
		return x;
	}
	for(int i=19;i>=0;i--)
	{
		if(far[x][i]!=far[y][i])
		{
			x=far[x][i];
			y=far[y][i];
		}
	}
	return far[x][0];
}
int getdist(int x,int y)
{
	int z=lca(x,y);
	int res=depth[x]+depth[y]-2*depth[z];
	return res;
}
int main()
{
//    freopen("D:\\code\\text\\input.txt","r",stdin);
	//freopen("D:\\code\\text\\output.txt","w",stdout);
	int t;
	gbtb;
	cin>>t;
	while(t--)
	{
		cin>>n>>q;
		int u,v;
		repd(i,1,n)
		{
			son[i].clear();
		}
		repd(i,2,n)
		{
			cin>>u>>v;
			son[v].pb(u);
			son[u].pb(v);
		}
		depth[0]=0;
		dfs(1,0);
		while(q--)
		{
			int a,b,c;
			cin>>a>>b>>c;
			int a1=getdist(a,1);
			int bc=getdist(b,c);
			int c1=getdist(c,1);
			int ac=getdist(a,c);
//			cout<<a1<<" "<<bc<<" "<<c1<<" "<<ac<<endl;
			if(a1<bc+c1)
			{
				cout<<"YES"<<endl;
			}else if(ac<=bc)
			{
				cout<<"YES"<<endl;
			}else if(lca(a,c)!=1&&a1<=bc+c1)
			{
				cout<<"YES"<<endl;
			}else
			{
				cout<<"NO"<<endl;
			}
		}
	}



    return 0;
}

inline void getInt(int* p) {
    char ch;
    do {
        ch = getchar();
    } while (ch == ' ' || ch == '\n');
    if (ch == '-') {
        *p = -(getchar() - '0');
        while ((ch = getchar()) >= '0' && ch <= '9') {
            *p = *p * 10 - ch + '0';
        }
    }
    else {
        *p = ch - '0';
        while ((ch = getchar()) >= '0' && ch <= '9') {
            *p = *p * 10 + ch - '0';
        }
    }
}