赛后补C题,这题型怎么越看越像虚树呢,然后试着用虚树+两次dfs找直径 AC了
将相同势力的点拿出来建虚树,然后树上找直径即可。
但是这树没说根,我们假设1为根,找直径的时候特判一下1节点就可以了。
#include<bits/stdc++.h> using namespace std; typedef long long ll; const ll inf=0x3f3f3f3; const int N=2e5+10; int dfn[N],numdfn; int n,max_deep; int a[N]; /* lca部分 */ struct LCA { int head[N],dep[N],fa[N][22],cnt; struct edge { int v,w,next; }e[N*2]; void add(int u,int v,int w) { e[++cnt]={v,w,head[u]};head[u]=cnt; e[++cnt]={u,w,head[v]};head[v]=cnt; } void dfs(int u,int d,int f) { dep[u]=d;fa[u][0]=f; dfn[u]=++numdfn; max_deep=max(max_deep,d); for(int i=head[u];i;i=e[i].next) { int v=e[i].v; if(v==f) continue; dfs(v,d+1,u); } } void init() { dfs(1,0,1); for(int k=1;k<=20;++k) for(int i=1;i<=n;++i) fa[i][k]=fa[fa[i][k-1]][k-1]; } int lca(int u,int v) { if(dep[u]<dep[v]) swap(u,v); for(int k=20;k>=0;--k) if(dep[fa[u][k]]>=dep[v]) u=fa[u][k]; if(u==v) return v; for(int k=20;k>=0;--k) if(fa[u][k]!=fa[v][k]) u=fa[u][k],v=fa[v][k]; return fa[u][0]; } int findlen(int u,int v) { return dep[u]+dep[v]-2*dep[lca(u,v)]; } }L; /* 虚树部分 */ vector<int>has[N]; vector<pair<int,int> >G[N]; int st[N],top,cnt; int numid; bool cmp(int x,int y) { return dfn[x]<dfn[y]; } void ins(int id,int x)//插入虚树新点 { if(top==1){st[++top]=x;return ;} int fa=L.lca(x,st[top]); if(fa==st[top]){//就在该链的下方 st[++top]=x;return ; } while(top>1&&dfn[st[top-1]]>=dfn[fa]){ int p1=st[top],p2=st[top-1];//该相邻两节点相连 //建虚树图 int w=L.findlen(p1,p2); G[p1].emplace_back(make_pair(p2,w)); G[p2].emplace_back(make_pair(p1,w)); --top; } if(st[top]!=fa){ int p1=st[top],p2=fa; int w=L.findlen(p1,p2); G[p1].emplace_back(make_pair(p2,w)); G[p2].emplace_back(make_pair(p1,w)); st[top]=fa; } st[++top]=x; } void build(int id,vector<int>& has) { sort(has.begin(),has.end()); has.erase(unique(has.begin(),has.end()),has.end()); sort(has.begin(),has.end(),cmp); st[top=1]=1; //top=0; for(int v:has) ins(id,v); //printf("top:%d\n",top); while(top>1) { int p1=st[top],p2=st[top-1]; int w=L.findlen(p1,p2); G[p1].emplace_back(make_pair(p2,w)); G[p2].emplace_back(make_pair(p1,w)); --top; } } int mx,mi; int tar; void dfs1(int u,int fa,int dep,int val,int ty) { //printf("u:%d fa:%d\n",u,fa); if(a[u]==ty&&dep>mi){ mi=dep,tar=u; } for(auto now:G[u]){ if(now.first==fa) continue; if(now.first==u) continue;//1可能跟自己又连了一条边 //printf("fa:%d u:%d v:%d w:%d\n",fa,u,now.first,now.second); dfs1(now.first,u,dep+now.second,val,ty); } if(val==1) G[u].clear(); } int bfs(int rt,int ty) { mi=-1; dfs1(rt,-1,0,0,ty); mi=-1; // puts(""); // printf("tar:%d\n",tar); // puts(""); dfs1(tar,-1,0,1,ty); // printf("ty:%d mi:%d\n",ty,mi); // puts(""); // puts(""); return mi; } int main() { scanf("%d",&n); for(int i=1;i<=n;++i) { scanf("%d",&a[i]); has[a[i]].push_back(i); mx=max(mx,a[i]); } for(int i=1;i<n;++i){ int u,v; scanf("%d%d",&u,&v); L.add(u,v,1); } L.init(); int ans=0; for(int i=1;i<=mx;++i){ if(has[i].size()==0) continue; build(i,has[i]); ans=max(ans,bfs(has[i][0],i)); } printf("%lld\n",1ll*ans*ans); } /* 5 2 1 2 1 1 1 2 2 3 3 4 4 5 */