赛后补C题,这题型怎么越看越像虚树呢,然后试着用虚树+两次dfs找直径 AC了
将相同势力的点拿出来建虚树,然后树上找直径即可。
但是这树没说根,我们假设1为根,找直径的时候特判一下1节点就可以了。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll inf=0x3f3f3f3;
const int N=2e5+10;
int dfn[N],numdfn;
int n,max_deep;
int a[N];
/*
lca部分
*/
struct LCA
{
int head[N],dep[N],fa[N][22],cnt;
struct edge
{
int v,w,next;
}e[N*2];
void add(int u,int v,int w)
{
e[++cnt]={v,w,head[u]};head[u]=cnt;
e[++cnt]={u,w,head[v]};head[v]=cnt;
}
void dfs(int u,int d,int f)
{
dep[u]=d;fa[u][0]=f;
dfn[u]=++numdfn;
max_deep=max(max_deep,d);
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==f) continue;
dfs(v,d+1,u);
}
}
void init()
{
dfs(1,0,1);
for(int k=1;k<=20;++k)
for(int i=1;i<=n;++i)
fa[i][k]=fa[fa[i][k-1]][k-1];
}
int lca(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
for(int k=20;k>=0;--k)
if(dep[fa[u][k]]>=dep[v]) u=fa[u][k];
if(u==v) return v;
for(int k=20;k>=0;--k)
if(fa[u][k]!=fa[v][k]) u=fa[u][k],v=fa[v][k];
return fa[u][0];
}
int findlen(int u,int v)
{
return dep[u]+dep[v]-2*dep[lca(u,v)];
}
}L;
/*
虚树部分
*/
vector<int>has[N];
vector<pair<int,int> >G[N];
int st[N],top,cnt;
int numid;
bool cmp(int x,int y)
{
return dfn[x]<dfn[y];
}
void ins(int id,int x)//插入虚树新点
{
if(top==1){st[++top]=x;return ;}
int fa=L.lca(x,st[top]);
if(fa==st[top]){//就在该链的下方
st[++top]=x;return ;
}
while(top>1&&dfn[st[top-1]]>=dfn[fa]){
int p1=st[top],p2=st[top-1];//该相邻两节点相连
//建虚树图
int w=L.findlen(p1,p2);
G[p1].emplace_back(make_pair(p2,w));
G[p2].emplace_back(make_pair(p1,w));
--top;
}
if(st[top]!=fa){
int p1=st[top],p2=fa;
int w=L.findlen(p1,p2);
G[p1].emplace_back(make_pair(p2,w));
G[p2].emplace_back(make_pair(p1,w));
st[top]=fa;
}
st[++top]=x;
}
void build(int id,vector<int>& has)
{
sort(has.begin(),has.end());
has.erase(unique(has.begin(),has.end()),has.end());
sort(has.begin(),has.end(),cmp);
st[top=1]=1;
//top=0;
for(int v:has) ins(id,v);
//printf("top:%d\n",top);
while(top>1)
{
int p1=st[top],p2=st[top-1];
int w=L.findlen(p1,p2);
G[p1].emplace_back(make_pair(p2,w));
G[p2].emplace_back(make_pair(p1,w));
--top;
}
}
int mx,mi;
int tar;
void dfs1(int u,int fa,int dep,int val,int ty)
{
//printf("u:%d fa:%d\n",u,fa);
if(a[u]==ty&&dep>mi){
mi=dep,tar=u;
}
for(auto now:G[u]){
if(now.first==fa) continue;
if(now.first==u) continue;//1可能跟自己又连了一条边
//printf("fa:%d u:%d v:%d w:%d\n",fa,u,now.first,now.second);
dfs1(now.first,u,dep+now.second,val,ty);
}
if(val==1) G[u].clear();
}
int bfs(int rt,int ty)
{
mi=-1;
dfs1(rt,-1,0,0,ty);
mi=-1;
// puts("");
// printf("tar:%d\n",tar);
// puts("");
dfs1(tar,-1,0,1,ty);
// printf("ty:%d mi:%d\n",ty,mi);
// puts("");
// puts("");
return mi;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;++i) {
scanf("%d",&a[i]);
has[a[i]].push_back(i);
mx=max(mx,a[i]);
}
for(int i=1;i<n;++i){
int u,v;
scanf("%d%d",&u,&v);
L.add(u,v,1);
}
L.init();
int ans=0;
for(int i=1;i<=mx;++i){
if(has[i].size()==0) continue;
build(i,has[i]);
ans=max(ans,bfs(has[i][0],i));
}
printf("%lld\n",1ll*ans*ans);
}
/*
5
2 1 2 1 1
1 2
2 3
3 4
4 5
*/

京公网安备 11010502036488号