假设节点颜色为图片说明 ,对于求颜色图片说明 在树上距离最大的两点,我们需要先找出树上颜色为图片说明 的距离最大的两点,找出树上颜色为图片说明 的距离最大的两点,然后枚举求不同颜色的两点距离,得到的最大值就是图片说明 的最远距离。

所以,我们只需要预处理出所有颜色的最远两点,然后计算的时候用LCA求距离即可。

当然,颜色可能很多,多达图片说明 种,并且树的节点也是图片说明 ,如果直接一一枚举颜色,然后在原树上dp求最远距离,复杂度将会是图片说明 ,肯定不行,我们想,原树上是不是有很多节点,我们可能并不需要,我们不用耗费时间去遍历,所以,我们需要在寻找一个新颜色的时候,重构一棵虚树,把不需要的节点全部给抛弃(不是说颜色不是所求的颜色就不需要,有可能有一些连接节点等,但是对复杂度影响不大),这样复杂度就可以大大减少,全部用到的节点次数可能只是n的个位数倍(n ~ k*n,k大多数情况下都是个位数),加上建立需要需要用到LCA,所以这部分复杂度可以预估为图片说明 ,常数k部分不计算,这样预处理出全部颜色的最远两点,时间复杂度很可观。

接下来就是求最远距离了,两种颜色的四个点,枚举最多有4种,两两之间通过LCA求距离,最大的就是答案,并且LCA复杂度是图片说明 ,所以这里查询的复杂度也只是图片说明

#include <cctype>
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <algorithm>
#include <deque>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#define ll long long
#define pll pair<long long, long long>
#define P pair<int, int>
#define PP pair<P, P>
#define eps 1e-6
#define It set<node>::iterator
using namespace std;
inline int Read() {
    int res=0;
    char ch=getchar();
    while (!isdigit(ch)) ch=getchar();
    while (isdigit(ch)) {
        res=res*10+ch-'0';
        ch=getchar();
    }
    return res;
}
const int maxn=1e5+10;
const long long INF=1ll<<62;
int top,G[maxn],Stack[maxn],Head[maxn],tot,dp[maxn][23],dep[maxn],dfn[maxn],cnt,lg[maxn],ans[maxn][2],col[maxn],point[maxn][2];
struct edge {
    int to,Next;
}Edge[maxn<<1];
void Add(int from, int to) {
    Edge[tot]={to,Head[from]};
    Head[from]=tot++;
    Edge[tot]={from,Head[to]};
    Head[to]=tot++;
}
void dfs(int u, int fa) {
    dfn[u]=++cnt;
    dp[u][0]=fa;
    dep[u]=dep[fa]+1;
    for (int i=1; i<lg[dep[u]]; i++) dp[u][i]=dp[dp[u][i-1]][i-1];
    for (int i=Head[u]; ~i; i=Edge[i].Next) {
        int v=Edge[i].to;
        if (v==fa) continue;
        dfs(v,u);
    }
}
int LCA(int a, int b) {
    if (dep[a]<dep[b]) swap(a,b);
    while (dep[a]>dep[b]) {
        a=dp[a][lg[dep[a]-dep[b]]-1];
    }
    if (a==b) return a;
    for (int i=lg[dep[a]]-1; i>=0; i--) {
        if (dp[a][i]!=dp[b][i]) {
            a=dp[a][i];
            b=dp[b][i];
        }
    }
    return dp[a][0];
}
int res,col1,Max[maxn][2],p[maxn][2];
void solve1(int u, int fa) {
    p[u][0]=p[u][1]=0;
    Max[u][0]=Max[u][1]=0;
    ans[u][0]=ans[u][1]=0;
    for (int i=Head[u]; ~i; i=Edge[i].Next) {
        int v=Edge[i].to;
        if (v==fa) continue;
        solve1(v,u);
        if (ans[v][0]) {
            if (ans[u][0]<ans[v][0]+dep[v]-dep[u]) {
                p[u][1]=p[u][0];
                p[u][0]=p[v][0];
                Max[u][1]=Max[u][0];
                Max[u][0]=v;
                ans[u][1]=ans[u][0];
                ans[u][0]=ans[v][0]+dep[v]-dep[u];
            } else if (ans[u][1]<ans[v][0]+dep[v]-dep[u]) {
                ans[u][1]=ans[v][0]+dep[v]-dep[u];
                Max[u][1]=v;
                p[u][1]=p[v][0];
            }
        } else if (col[v]==col1) {
            if (ans[u][0]<dep[v]-dep[u]) {
                Max[u][1]=Max[u][0];
                Max[u][0]=v;
                p[u][1]=p[u][0];
                p[u][0]=v;
                ans[u][1]=ans[u][0];
                ans[u][0]=dep[v]-dep[u];
            } else if (ans[u][1]<dep[v]-dep[u]) {
                ans[u][1]=dep[v]-dep[u];
                Max[u][1]=v;
                p[u][1]=v;
            }
        }
    }
    if (Max[u][0]&&Max[u][1]) {
        if (res<ans[u][0]+ans[u][1]) {
            res=ans[u][0]+ans[u][1];
            point[col1][0]=p[u][0];
            point[col1][1]=p[u][1];
        }
    }
    if (col[u]==col1) {
        if (Max[u][0]) {
            if (res<ans[u][0]) {
                res=ans[u][0];
                point[col1][0]=p[u][0];
                point[col1][1]=u;
            }
        }
    }
    if (col[u]==col1&&point[col1][0]==0) point[col1][0]=u;
}
bool myfun(int a, int b) {
    return dfn[a]<dfn[b];
}
vector<int> vc[maxn];
map<int,int> M;
map<P,int> Ans;
int main() {
    int n,k;
    for (int i=1; i<maxn; i++) {
        lg[i]=lg[i-1]+(i==(1<<lg[i-1]));
    }
    memset(Head,-1,sizeof(Head));
    n=Read();
    k=Read();
    int idx=0;
    for (int i=1; i<=n; i++) {
        col[i]=Read();
        if (!M.count(col[i])) M[col[i]]=++idx;
        col[i]=M[col[i]];
        vc[col[i]].push_back(i);
    }
    for (int i=1; i<n; i++) {
        int u,v;
        u=Read(),v=Read();
        Add(u,v);
    }
    dfs(1,0);
    for (int j=1; j<=idx; j++) {
        int pos=0;
        col1=j;
        for (int i=0; i<vc[j].size(); i++) {
            int num=vc[j][i];
            G[pos++]=num;
            Head[num]=-1;
        }
        sort(G,G+pos,myfun);
        Stack[top=1]=1;
        Head[1]=-1;
        tot=0;
        for (int i=0; i<pos; i++) {
            if (G[i]==1) continue;
            int lca=LCA(Stack[top],G[i]);
            while (top>1&&dfn[lca]<=dfn[Stack[top-1]]) {
                Add(Stack[top],Stack[top-1]);
                top--;
            }
            if (top>1&&dfn[lca]!=dfn[Stack[top]]) {
                Head[lca]=-1;
                Add(lca,Stack[top]);
                Stack[top]=lca;
            }
            Stack[++top]=G[i];
        }
        while (top>1) {
            Add(Stack[top],Stack[top-1]);
            top--;
        }
        res=0;
        solve1(1,0);
    }
    while (k--) {
        int col1,col2;
        col1=Read(),col2=Read();
        if (!M.count(col1)||!M.count(col2)) {
            printf("0\n");
            continue;
        }
        col1=M[col1];
        col2=M[col2];
        if (col1==col2) {
            int res=0;
            if (point[col1][0]&&point[col1][1]) res=dep[point[col1][0]]+dep[point[col1][1]]-2*dep[LCA(point[col1][0],point[col1][1])];
            printf("%d\n",res);
        } else {
            int res=0;
            if (point[col1][0]&&point[col2][0]) res=max(res,dep[point[col1][0]]+dep[point[col2][0]]-2*dep[LCA(point[col1][0],point[col2][0])]);
            if (point[col1][0]&&point[col2][1]) res=max(res,dep[point[col1][0]]+dep[point[col2][1]]-2*dep[LCA(point[col1][0],point[col2][1])]);
            if (point[col1][1]&&point[col2][0]) res=max(res,dep[point[col1][1]]+dep[point[col2][0]]-2*dep[LCA(point[col1][1],point[col2][0])]);
            if (point[col1][1]&&point[col2][1]) res=max(res,dep[point[col1][1]]+dep[point[col2][1]]-2*dep[LCA(point[col1][1],point[col2][1])]);
            printf("%d\n",res);
        }
    }
    return 0;
}