假设节点颜色为 ,对于求颜色 在树上距离最大的两点,我们需要先找出树上颜色为 的距离最大的两点,找出树上颜色为 的距离最大的两点,然后枚举求不同颜色的两点距离,得到的最大值就是 的最远距离。
所以,我们只需要预处理出所有颜色的最远两点,然后计算的时候用LCA求距离即可。
当然,颜色可能很多,多达 种,并且树的节点也是 ,如果直接一一枚举颜色,然后在原树上dp求最远距离,复杂度将会是 ,肯定不行,我们想,原树上是不是有很多节点,我们可能并不需要,我们不用耗费时间去遍历,所以,我们需要在寻找一个新颜色的时候,重构一棵虚树,把不需要的节点全部给抛弃(不是说颜色不是所求的颜色就不需要,有可能有一些连接节点等,但是对复杂度影响不大),这样复杂度就可以大大减少,全部用到的节点次数可能只是n的个位数倍(n ~ k*n,k大多数情况下都是个位数),加上建立需要需要用到LCA,所以这部分复杂度可以预估为 ,常数k部分不计算,这样预处理出全部颜色的最远两点,时间复杂度很可观。
接下来就是求最远距离了,两种颜色的四个点,枚举最多有4种,两两之间通过LCA求距离,最大的就是答案,并且LCA复杂度是 ,所以这里查询的复杂度也只是 。
#include <cctype> #include <cfloat> #include <cmath> #include <cstdio> #include <cstdlib> #include <cstring> #include <ctime> #include <algorithm> #include <deque> #include <fstream> #include <functional> #include <iomanip> #include <iostream> #include <istream> #include <iterator> #include <list> #include <map> #include <ostream> #include <queue> #include <set> #include <sstream> #include <stack> #include <string> #include <utility> #include <vector> #include <unordered_map> #include <unordered_set> #define ll long long #define pll pair<long long, long long> #define P pair<int, int> #define PP pair<P, P> #define eps 1e-6 #define It set<node>::iterator using namespace std; inline int Read() { int res=0; char ch=getchar(); while (!isdigit(ch)) ch=getchar(); while (isdigit(ch)) { res=res*10+ch-'0'; ch=getchar(); } return res; } const int maxn=1e5+10; const long long INF=1ll<<62; int top,G[maxn],Stack[maxn],Head[maxn],tot,dp[maxn][23],dep[maxn],dfn[maxn],cnt,lg[maxn],ans[maxn][2],col[maxn],point[maxn][2]; struct edge { int to,Next; }Edge[maxn<<1]; void Add(int from, int to) { Edge[tot]={to,Head[from]}; Head[from]=tot++; Edge[tot]={from,Head[to]}; Head[to]=tot++; } void dfs(int u, int fa) { dfn[u]=++cnt; dp[u][0]=fa; dep[u]=dep[fa]+1; for (int i=1; i<lg[dep[u]]; i++) dp[u][i]=dp[dp[u][i-1]][i-1]; for (int i=Head[u]; ~i; i=Edge[i].Next) { int v=Edge[i].to; if (v==fa) continue; dfs(v,u); } } int LCA(int a, int b) { if (dep[a]<dep[b]) swap(a,b); while (dep[a]>dep[b]) { a=dp[a][lg[dep[a]-dep[b]]-1]; } if (a==b) return a; for (int i=lg[dep[a]]-1; i>=0; i--) { if (dp[a][i]!=dp[b][i]) { a=dp[a][i]; b=dp[b][i]; } } return dp[a][0]; } int res,col1,Max[maxn][2],p[maxn][2]; void solve1(int u, int fa) { p[u][0]=p[u][1]=0; Max[u][0]=Max[u][1]=0; ans[u][0]=ans[u][1]=0; for (int i=Head[u]; ~i; i=Edge[i].Next) { int v=Edge[i].to; if (v==fa) continue; solve1(v,u); if (ans[v][0]) { if (ans[u][0]<ans[v][0]+dep[v]-dep[u]) { p[u][1]=p[u][0]; p[u][0]=p[v][0]; Max[u][1]=Max[u][0]; Max[u][0]=v; ans[u][1]=ans[u][0]; ans[u][0]=ans[v][0]+dep[v]-dep[u]; } else if (ans[u][1]<ans[v][0]+dep[v]-dep[u]) { ans[u][1]=ans[v][0]+dep[v]-dep[u]; Max[u][1]=v; p[u][1]=p[v][0]; } } else if (col[v]==col1) { if (ans[u][0]<dep[v]-dep[u]) { Max[u][1]=Max[u][0]; Max[u][0]=v; p[u][1]=p[u][0]; p[u][0]=v; ans[u][1]=ans[u][0]; ans[u][0]=dep[v]-dep[u]; } else if (ans[u][1]<dep[v]-dep[u]) { ans[u][1]=dep[v]-dep[u]; Max[u][1]=v; p[u][1]=v; } } } if (Max[u][0]&&Max[u][1]) { if (res<ans[u][0]+ans[u][1]) { res=ans[u][0]+ans[u][1]; point[col1][0]=p[u][0]; point[col1][1]=p[u][1]; } } if (col[u]==col1) { if (Max[u][0]) { if (res<ans[u][0]) { res=ans[u][0]; point[col1][0]=p[u][0]; point[col1][1]=u; } } } if (col[u]==col1&&point[col1][0]==0) point[col1][0]=u; } bool myfun(int a, int b) { return dfn[a]<dfn[b]; } vector<int> vc[maxn]; map<int,int> M; map<P,int> Ans; int main() { int n,k; for (int i=1; i<maxn; i++) { lg[i]=lg[i-1]+(i==(1<<lg[i-1])); } memset(Head,-1,sizeof(Head)); n=Read(); k=Read(); int idx=0; for (int i=1; i<=n; i++) { col[i]=Read(); if (!M.count(col[i])) M[col[i]]=++idx; col[i]=M[col[i]]; vc[col[i]].push_back(i); } for (int i=1; i<n; i++) { int u,v; u=Read(),v=Read(); Add(u,v); } dfs(1,0); for (int j=1; j<=idx; j++) { int pos=0; col1=j; for (int i=0; i<vc[j].size(); i++) { int num=vc[j][i]; G[pos++]=num; Head[num]=-1; } sort(G,G+pos,myfun); Stack[top=1]=1; Head[1]=-1; tot=0; for (int i=0; i<pos; i++) { if (G[i]==1) continue; int lca=LCA(Stack[top],G[i]); while (top>1&&dfn[lca]<=dfn[Stack[top-1]]) { Add(Stack[top],Stack[top-1]); top--; } if (top>1&&dfn[lca]!=dfn[Stack[top]]) { Head[lca]=-1; Add(lca,Stack[top]); Stack[top]=lca; } Stack[++top]=G[i]; } while (top>1) { Add(Stack[top],Stack[top-1]); top--; } res=0; solve1(1,0); } while (k--) { int col1,col2; col1=Read(),col2=Read(); if (!M.count(col1)||!M.count(col2)) { printf("0\n"); continue; } col1=M[col1]; col2=M[col2]; if (col1==col2) { int res=0; if (point[col1][0]&&point[col1][1]) res=dep[point[col1][0]]+dep[point[col1][1]]-2*dep[LCA(point[col1][0],point[col1][1])]; printf("%d\n",res); } else { int res=0; if (point[col1][0]&&point[col2][0]) res=max(res,dep[point[col1][0]]+dep[point[col2][0]]-2*dep[LCA(point[col1][0],point[col2][0])]); if (point[col1][0]&&point[col2][1]) res=max(res,dep[point[col1][0]]+dep[point[col2][1]]-2*dep[LCA(point[col1][0],point[col2][1])]); if (point[col1][1]&&point[col2][0]) res=max(res,dep[point[col1][1]]+dep[point[col2][0]]-2*dep[LCA(point[col1][1],point[col2][0])]); if (point[col1][1]&&point[col2][1]) res=max(res,dep[point[col1][1]]+dep[point[col2][1]]-2*dep[LCA(point[col1][1],point[col2][1])]); printf("%d\n",res); } } return 0; }