题目描述
一棵树,每个节点有一个颜色,q个询问
每次询问x,y
求树上距离最长的两点
其中一点color【i】=x;
另外一点color【i】=y;
思路
1.对于求树上两点之间的距离可以用lca,
倍增法求lca也不再介绍
2.求max(color[i],color[j]]
如果只看一种颜色x求最长的距离可以枚举是x颜色的每个点
我们暂且叫颜色x的直径
如果看两种颜色x,y 那么max是x直径的两端点中的其中一点和y直径中两端点中的其中一点

由于color[i]比较大要可以先离散化
代码

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define UpMing main
#define re register
#pragma GCC optimize(2)
#define Accept return 0;
#define lowbit(x) ((x)&(-(x)))
#define mst(x, a) memset( x,a,sizeof(x) )
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define dep(i,a,b) for(int i=(a);i>=(b);--i)
using namespace std;
typedef long long ll;
typedef pair PII;
typedef unsigned long long ull;
const int inf =0x3f3f3f3f;
const int maxn=4e5+66;
const ll mod = 1e9+7;
const int N =1e6+3;
inline ll read() {
    ll  x=0;
    bool f=0;
    char ch=getchar();
    while (ch<'0'||'9'<ch)    f|=ch=='-', ch=getchar();
    while ('0'<=ch && ch<='9')
        x=x*10+ch-'0',ch=getchar();
    return f?-x:x;
}
void out(ll x) {
    int stackk[20];
    if(x<0) {
        putchar('-');
        x=-x;
    }
    if(!x) {
        putchar('0');
        return;
    }
    int top=0;
    while(x) stackk[++top]=x%10,x/=10;
    while(top) putchar(stackk[top--]+'0');
}
int  n,q,a[maxn],b[maxn],fa[200002][200],len,depth[maxn];
vectorv[maxn],c[maxn];
void dfs(int  t,int  p) {
    depth[t]=depth[p]+1;
    fa[t][0]=p;
    for(int i=1 ; i<=len; i++) fa[t][i]=fa[fa[t][i-1]][i-1];
    for(auto u:v[t]) {
        if(u==p) continue;
        dfs(u,t);
    }
}
int lca(int  x,int  y) {
    if(depth[x]<depth[y]) swap(x,y);
    for(int i=len ; i>=0 ; i--) {
        if(depth[fa[x][i]]>=depth[y])
            x=fa[x][i];
    }
    if(x==y) return x;
    for(int i=len; i>=0 ; i--) {
        ll dx=fa[x][i];
        ll dy=fa[y][i];
        if(dx!=dy) x=dx,y=dy;
    }
    return fa[x][0];
}
int  dis(int  x,int  y) {
    return depth[x]+depth[y]-2*depth[lca(x,y)];
}
mapmp;
int UpMing() {
    n=read();
    q=read();
    len=log2(n)+1;
    for(int i=1 ; i<=n ; i++)
        a[i]=read(),b[i]=a[i],mp[a[i]]++;
    sort(b+1,b+1+n);
    ll pos=unique(b+1,b+1+n)-(b+1);
    for(int i=1 ; i<=n ; i++) {
        a[i]=lower_bound(b+1,b+1+pos,a[i])-b;
        c[a[i]].push_back(i);
    }
    for(int i=1 ; i<n ; i++) {
        ll x=read();
        ll y=read();
        v[x].push_back(y);
        v[y].push_back(x);
    }
    dfs(1,0);
    for(int i=1 ; i<=pos; i++) {
        for(int j=0 ; j<c[i].size(); j++) {
            ll dis1=dis(c[i][0],c[i][j]);
            ll dis2=dis(c[i][1],c[i][j]);
            ll dis3=dis(c[i][0],c[i][1]);
            if(dis1>dis3&&dis2>dis3) {
                if(dis1>dis2)  c[i][1]=c[i][j];
                else c[i][0]=c[i][j];
            } else if(dis1>dis3) c[i][1]=c[i][j];
            else if(dis2>dis3) c[i][0]=c[i][j];
        }
    }
    while(q--) {
        ll x=read();
        ll y=read();
        if(mp[x]==0||mp[y]==0) {
            printf("0\n");
            continue;
        }
        x=lower_bound(b+1,b+1+pos,x)-b;
        y=lower_bound(b+1,b+1+pos,y)-b;
        ll dis1=0;
        ll dis2=0;
        ll dis3=0;
        ll dis4=0;
        dis1=dis(c[x][0],c[y][0]);
        if(c[x].size()>1&&c[y].size()>1) {
            dis4=dis(c[x][1],c[y][1]);
            dis3=dis(c[x][0],c[y][1]);
            dis2=dis(c[x][1],c[y][0]);
        } else if(c[y].size()>1)
            dis3=dis(c[x][0],c[y][1]);
        else if(c[x].size()>1)
            dis2=dis(c[x][1],c[y][0]);
        out(max(max(max(dis4,dis3),dis1),dis2));
        printf("\n");
    }
    Accept;
}