题目
题解
边分治
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=200002;
struct node{
int to,ne,w;
}e[N<<1];
struct kk{
int v,l;
}t[2][N];
int n,nn,i,v[N],x,y,h[N],tot,sz[N],rt,sum,mx,c[2];
ll ans;
vector<int>a[N];
bool vis[N];
inline char gc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int rd(){
int x=0,fl=1;char ch=gc();
for(;ch<48||ch>57;ch=gc())if(ch=='-')fl=-1;
for(;48<=ch&&ch<=57;ch=gc())x=(x<<3)+(x<<1)+(ch^48);
return x*fl;
}
void add(int x,int y,int z){e[++tot]=(node){y,h[x],z},h[x]=tot;}
bool cmp(kk a,kk b){return a.v>b.v;}
void dfs1(int u,int fa){
for (int i=h[u],v;i;i=e[i].ne)
if ((v=e[i].to)!=fa) a[u].push_back(v),dfs1(v,u);
}
void rebuild(){
tot=1,memset(h,0,(n+1)<<2);
for (int i=1;i<=n;i++){
int sz=a[i].size();
if (sz<=2)
for (int j=0;j<sz;j++) add(i,a[i][j],a[i][j]<=nn),add(a[i][j],i,a[i][j]<=nn);
else{
int o1=++n,o2=++n;
v[o1]=v[o2]=v[i];
add(i,o1,0),add(o1,i,0),add(i,o2,0),add(o2,i,0);
for (int j=0;j<sz;j++) a[j&1?o1:o2].push_back(a[i][j]);
}
}
}
void getrt(int u,int fa){
sz[u]=1;
for (int i=h[u],v;i;i=e[i].ne)
if ((v=e[i].to)!=fa && !vis[i>>1]){
getrt(v,u);
sz[u]+=sz[v];
int tmp=max(sz[v],sum-sz[v]);
if (tmp<mx) mx=tmp,rt=i;
}
}
void dfs2(int o,int u,int fa,int len,int val){
val=min(val,v[u]),t[o][c[o]++]=(kk){val,len};
for (int i=h[u],v;i;i=e[i].ne)
if ((v=e[i].to)!=fa && !vis[i>>1]) dfs2(o,v,u,len+e[i].w,val);
}
void solve(int u,int p){
mx=1e9,sum=p,getrt(u,0);
if (mx==1e9) return;
int now=rt;
c[0]=c[1]=0,vis[now>>1]=1;
dfs2(0,e[now].to,0,0,1e9);
dfs2(1,e[now^1].to,0,0,1e9);
sort(t[0],t[0]+c[0],cmp);
sort(t[1],t[1]+c[1],cmp);
for (int i=0,j=0,len=-1e9;i<c[0];i++){
for (;j<c[1] && t[1][j].v>=t[0][i].v;j++) len=max(len,t[1][j].l);
ans=max(ans,1ll*t[0][i].v*(len+e[now].w+t[0][i].l+1));
}
for (int i=0,j=0,len=-1e9;i<c[1];i++){
for (;j<c[0] && t[0][j].v>=t[1][i].v;j++) len=max(len,t[0][j].l);
ans=max(ans,1ll*t[1][i].v*(len+e[now].w+t[1][i].l+1));
}
int SZ=sz[e[now].to];
solve(e[now].to,SZ),solve(e[now^1].to,p-SZ);
}
int main(){
n=nn=rd();
for (i=1;i<=n;i++) v[i]=rd();
for (i=1;i<n;i++) x=rd(),y=rd(),add(x,y,1),add(y,x,1);
dfs1(1,0);
rebuild();
solve(1,n);
printf("%lld",ans);
}