假设节点颜色为 ,对于求颜色
在树上距离最大的两点,我们需要先找出树上颜色为
的距离最大的两点,找出树上颜色为
的距离最大的两点,然后枚举求不同颜色的两点距离,得到的最大值就是
的最远距离。
所以,我们只需要预处理出所有颜色的最远两点,然后计算的时候用LCA求距离即可。
当然,颜色可能很多,多达 种,并且树的节点也是
,如果直接一一枚举颜色,然后在原树上dp求最远距离,复杂度将会是
,肯定不行,我们想,原树上是不是有很多节点,我们可能并不需要,我们不用耗费时间去遍历,所以,我们需要在寻找一个新颜色的时候,重构一棵虚树,把不需要的节点全部给抛弃(不是说颜色不是所求的颜色就不需要,有可能有一些连接节点等,但是对复杂度影响不大),这样复杂度就可以大大减少,全部用到的节点次数可能只是n的个位数倍(n ~ k*n,k大多数情况下都是个位数),加上建立需要需要用到LCA,所以这部分复杂度可以预估为
,常数k部分不计算,这样预处理出全部颜色的最远两点,时间复杂度很可观。
接下来就是求最远距离了,两种颜色的四个点,枚举最多有4种,两两之间通过LCA求距离,最大的就是答案,并且LCA复杂度是 ,所以这里查询的复杂度也只是
。
#include <cctype>
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <algorithm>
#include <deque>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#define ll long long
#define pll pair<long long, long long>
#define P pair<int, int>
#define PP pair<P, P>
#define eps 1e-6
#define It set<node>::iterator
using namespace std;
inline int Read() {
int res=0;
char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) {
res=res*10+ch-'0';
ch=getchar();
}
return res;
}
const int maxn=1e5+10;
const long long INF=1ll<<62;
int top,G[maxn],Stack[maxn],Head[maxn],tot,dp[maxn][23],dep[maxn],dfn[maxn],cnt,lg[maxn],ans[maxn][2],col[maxn],point[maxn][2];
struct edge {
int to,Next;
}Edge[maxn<<1];
void Add(int from, int to) {
Edge[tot]={to,Head[from]};
Head[from]=tot++;
Edge[tot]={from,Head[to]};
Head[to]=tot++;
}
void dfs(int u, int fa) {
dfn[u]=++cnt;
dp[u][0]=fa;
dep[u]=dep[fa]+1;
for (int i=1; i<lg[dep[u]]; i++) dp[u][i]=dp[dp[u][i-1]][i-1];
for (int i=Head[u]; ~i; i=Edge[i].Next) {
int v=Edge[i].to;
if (v==fa) continue;
dfs(v,u);
}
}
int LCA(int a, int b) {
if (dep[a]<dep[b]) swap(a,b);
while (dep[a]>dep[b]) {
a=dp[a][lg[dep[a]-dep[b]]-1];
}
if (a==b) return a;
for (int i=lg[dep[a]]-1; i>=0; i--) {
if (dp[a][i]!=dp[b][i]) {
a=dp[a][i];
b=dp[b][i];
}
}
return dp[a][0];
}
int res,col1,Max[maxn][2],p[maxn][2];
void solve1(int u, int fa) {
p[u][0]=p[u][1]=0;
Max[u][0]=Max[u][1]=0;
ans[u][0]=ans[u][1]=0;
for (int i=Head[u]; ~i; i=Edge[i].Next) {
int v=Edge[i].to;
if (v==fa) continue;
solve1(v,u);
if (ans[v][0]) {
if (ans[u][0]<ans[v][0]+dep[v]-dep[u]) {
p[u][1]=p[u][0];
p[u][0]=p[v][0];
Max[u][1]=Max[u][0];
Max[u][0]=v;
ans[u][1]=ans[u][0];
ans[u][0]=ans[v][0]+dep[v]-dep[u];
} else if (ans[u][1]<ans[v][0]+dep[v]-dep[u]) {
ans[u][1]=ans[v][0]+dep[v]-dep[u];
Max[u][1]=v;
p[u][1]=p[v][0];
}
} else if (col[v]==col1) {
if (ans[u][0]<dep[v]-dep[u]) {
Max[u][1]=Max[u][0];
Max[u][0]=v;
p[u][1]=p[u][0];
p[u][0]=v;
ans[u][1]=ans[u][0];
ans[u][0]=dep[v]-dep[u];
} else if (ans[u][1]<dep[v]-dep[u]) {
ans[u][1]=dep[v]-dep[u];
Max[u][1]=v;
p[u][1]=v;
}
}
}
if (Max[u][0]&&Max[u][1]) {
if (res<ans[u][0]+ans[u][1]) {
res=ans[u][0]+ans[u][1];
point[col1][0]=p[u][0];
point[col1][1]=p[u][1];
}
}
if (col[u]==col1) {
if (Max[u][0]) {
if (res<ans[u][0]) {
res=ans[u][0];
point[col1][0]=p[u][0];
point[col1][1]=u;
}
}
}
if (col[u]==col1&&point[col1][0]==0) point[col1][0]=u;
}
bool myfun(int a, int b) {
return dfn[a]<dfn[b];
}
vector<int> vc[maxn];
map<int,int> M;
map<P,int> Ans;
int main() {
int n,k;
for (int i=1; i<maxn; i++) {
lg[i]=lg[i-1]+(i==(1<<lg[i-1]));
}
memset(Head,-1,sizeof(Head));
n=Read();
k=Read();
int idx=0;
for (int i=1; i<=n; i++) {
col[i]=Read();
if (!M.count(col[i])) M[col[i]]=++idx;
col[i]=M[col[i]];
vc[col[i]].push_back(i);
}
for (int i=1; i<n; i++) {
int u,v;
u=Read(),v=Read();
Add(u,v);
}
dfs(1,0);
for (int j=1; j<=idx; j++) {
int pos=0;
col1=j;
for (int i=0; i<vc[j].size(); i++) {
int num=vc[j][i];
G[pos++]=num;
Head[num]=-1;
}
sort(G,G+pos,myfun);
Stack[top=1]=1;
Head[1]=-1;
tot=0;
for (int i=0; i<pos; i++) {
if (G[i]==1) continue;
int lca=LCA(Stack[top],G[i]);
while (top>1&&dfn[lca]<=dfn[Stack[top-1]]) {
Add(Stack[top],Stack[top-1]);
top--;
}
if (top>1&&dfn[lca]!=dfn[Stack[top]]) {
Head[lca]=-1;
Add(lca,Stack[top]);
Stack[top]=lca;
}
Stack[++top]=G[i];
}
while (top>1) {
Add(Stack[top],Stack[top-1]);
top--;
}
res=0;
solve1(1,0);
}
while (k--) {
int col1,col2;
col1=Read(),col2=Read();
if (!M.count(col1)||!M.count(col2)) {
printf("0\n");
continue;
}
col1=M[col1];
col2=M[col2];
if (col1==col2) {
int res=0;
if (point[col1][0]&&point[col1][1]) res=dep[point[col1][0]]+dep[point[col1][1]]-2*dep[LCA(point[col1][0],point[col1][1])];
printf("%d\n",res);
} else {
int res=0;
if (point[col1][0]&&point[col2][0]) res=max(res,dep[point[col1][0]]+dep[point[col2][0]]-2*dep[LCA(point[col1][0],point[col2][0])]);
if (point[col1][0]&&point[col2][1]) res=max(res,dep[point[col1][0]]+dep[point[col2][1]]-2*dep[LCA(point[col1][0],point[col2][1])]);
if (point[col1][1]&&point[col2][0]) res=max(res,dep[point[col1][1]]+dep[point[col2][0]]-2*dep[LCA(point[col1][1],point[col2][0])]);
if (point[col1][1]&&point[col2][1]) res=max(res,dep[point[col1][1]]+dep[point[col2][1]]-2*dep[LCA(point[col1][1],point[col2][1])]);
printf("%d\n",res);
}
}
return 0;
}


京公网安备 11010502036488号