题干:

时间限制:10000ms

单点时限:1000ms

内存限制:256MB

描述

给定一棵N的节点的树,节点编号1~N,并且1号节点是根节点。  

小Hi会反复询问小Ho一个问题:给定两个节点a和b,有多少对节点c和d满足c < d且c到d的路径包含完整的a到b的路径?

你能帮帮小Ho吗?

输入

第一行包含两个数N和M,依次是节点总数和问题总数。  

第2~N行每行包含两个整数u和v,代表u是v的父节点。  

以下M行每行包含两个整数a和b,代表一个问题。

对于30%的数据,1 ≤ N, M ≤ 1000

对于100%的数据,1 ≤ N, M ≤ 100000

输出

对于每个问题输出一个整数,代表答案。

样例输入

7 2 
1 2  
1 3  
2 4  
2 5  
3 6  
3 7  
2 3  
2 4

样例输出

9  
6

解题报告:

  跑半遍LCA,到他俩深度相同的时候停止。然后判断是否在同一条链上,分别返回不同的答案就行了。注意在同一条链的时候,不能用u和v的,需要用u(深度大的)的和对应链上dep[v]+1的那个点的。(来看几发错误代码)

错误代码1:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<string>
#include<cmath>
#include<cstring>
#define ll long long
#define pb push_back
#define pm make_pair
using namespace std;
const int MAX = 2e5 + 5;
ll a[MAX];
int fa[MAX][33],dep[MAX],sum[MAX];
int n,m;
vector<int> vv[MAX];
void dfs(int cur,int rt) {
	dep[cur] = dep[rt]+1;
	sum[cur] = 1;
	for(int i = 1; i<=31; i++) {
		fa[cur][i] = fa[fa[cur][i-1]][i-1];
	}
	int sz = vv[cur].size();
	for(int i = 0; i<sz; i++) {
		int v = vv[cur][i];
		if(v == rt) continue;
		dfs(v,cur);
		sum[cur] += sum[v];
	}
}
int lca(int u,int v) {
	
	if(dep[u] < dep[v]) swap(u,v);
	ll ans1 = 1LL*sum[u] * (n - sum[v] + (sum[v]-sum[u]));
	ll ans2 = 1LL*sum[u] * sum[v];
	int dc = dep[u] - dep[v];
	for(int i = 0; i<=31; i++) {
		if(dc>>i & 1) u = fa[u][i];
	}
	if(u == v) {
		return ans1;
	}
	else return ans2;
	for(int i = 31; i>=0 && u != v ; i--) {
		if(fa[u][i] != fa[v][i]) {
			u = fa[u][i];
			v = fa[v][i];
		}
	}
	int res = fa[u][0];//u和v的最近公共祖先.
	 
}
int main()
{
	cin>>n>>m;
	for(int u,v,i = 1; i<=n-1; i++) {
		scanf("%d%d",&u,&v);
		fa[v][0] = u;
		vv[u].pb(v);
		vv[v].pb(u);
	}
	dfs(1,0);
	while(m--) {
		int u,v;
		cin>>u>>v;
		cout << lca(u,v) <<endl;
	}
	return 0 ;
 }

错误代码2:

ll lca(int u,int v) {
	
	if(dep[u] < dep[v]) swap(u,v);
	//ll ans1 = (n - sum[v] + (sum[v]-sum[u]));
	ll sumu = sum[u] ,sumv = sum[v];
	int dc = dep[u] - dep[v];
	for(int i = 31; i>=0; i--) {
		//if(dc>>i & 1) u = fa[u][i];
		if(dep[u]-1 != dep[v]) {
			u = fa[u][i];
		} 
	}
	if(fa[u][0] == v) {//说明在一条链上 
		return sumu*(n-sum[u]);
	}
	else return sumu*sumv;
	for(int i = 31; i>=0 && u != v ; i--) {
		if(fa[u][i] != fa[v][i]) {
			u = fa[u][i];
			v = fa[v][i];
		}
	}
	int res = fa[u][0];//u和v的最近公共祖先.
}

错误代码3:

ll lca(int u,int v) {
	
	if(dep[u] < dep[v]) swap(u,v);
	//ll ans1 = (n - sum[v] + (sum[v]-sum[u]));
	ll sumu = sum[u] ,sumv = sum[v];
	int dc = dep[u] - dep[v];
	for(int i = 31; i>=0; i--) {
		//if(dc>>i & 1) u = fa[u][i];
		if(dep[u] < dep[v]) {
			u = fa[u][i];
		} 
	}
	if(fa[u][0] == v) {//说明在一条链上 
		return sumu*(n-sum[u]);
	}
	else return sumu*sumv;
	for(int i = 31; i>=0 && u != v ; i--) {
		if(fa[u][i] != fa[v][i]) {
			u = fa[u][i];
			v = fa[v][i];
		}
	}
	int res = fa[u][0];//u和v的最近公共祖先.
}

错误代码4:

ll lca(int u,int v) {
	if(dep[u] < dep[v]) swap(u,v);
	ll sumu = sum[u] ,sumv = sum[v];
	for(int i = 31; i>=0 && dep[v] + 1 <= dep[u]; i--) {
		if(dep[fa[u][i]] < dep[v]) {
			u = fa[u][i];
		} 
	}
	if(fa[u][0] == v) {//说明在一条链上 
		return sumu*(n-sum[u]);
	}
	else return sumu*sumv;
}

AC代码:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<string>
#include<cmath>
#include<cstring>
#define ll long long
#define pb push_back
#define pm make_pair
using namespace std;
const int MAX = 2e5 + 5;
ll a[MAX];
int fa[MAX][33],dep[MAX],sum[MAX];
int n,m;
vector<int> vv[MAX];
void dfs(int cur,int rt) {
	dep[cur] = dep[rt]+1;
	sum[cur] = 1;
	for(int i = 1; i<=31; i++) {
		fa[cur][i] = fa[fa[cur][i-1]][i-1];
	}
	int sz = vv[cur].size();
	for(int i = 0; i<sz; i++) {
		int v = vv[cur][i];
		if(v == rt) continue;
		dfs(v,cur);
		sum[cur] += sum[v];
	}
}
ll lca(int u,int v) {
	if(dep[u] < dep[v]) swap(u,v);
	ll sumu = sum[u] ,sumv = sum[v];
	for(int i = 31; i>=0 && dep[v] + 1 <= dep[u]; i--) {
		if(dep[fa[u][i]] < dep[v]) {
			u = fa[u][i];
		} 
	}
	if(fa[u][0] == v) {//说明在一条链上 
		return sumu*(n-sum[u]);
	}
	else return sumu*sumv;
}
int main()
{
	cin>>n>>m;
	for(int u,v,i = 1; i<=n-1; i++) {
		scanf("%d%d",&u,&v);
		fa[v][0] = u;
		vv[u].pb(v);
		vv[v].pb(u);
	}
	dfs(1,0);
	while(m--) {
		int u,v;
		scanf("%d%d",&u,&v);//cin>>u>>v;
		cout << lca(u,v) <<endl;
	}
	return 0 ;
 }

总结:

  另外注意一下,,,对于一棵树,dep值大的是在下面啊(远离根),别想反了。