比较明显的点分治问题,考虑如何计算通过点的路径贡献
获得从出发的所有路径,显然经过的所有路径都是由出发的两条路径拼接而来
每条路径存成一个,第一关键字为最大值,第二关键词为最小值
现在就是一个简单的二维偏序问题
一下那么每个位置的最大值一定比前面大
然后使用动态开点权值线段树,最小值为下标,维护每个位置的数量和权值和即可
比如对于第条路径,最大权值为,最小权值为,考虑和前条路径合并的贡献
在线段树上查询一下权值为的个数和权值和即可
当作为路径中最小值时,答案是
当其他路径作为最小值时,答案是
动态开点权值线段树常数略大,需要写成类型的版本
也可以写离散化的树状数组,就不需要卡常数
实现的时候需要容斥一下,毕竟两条路径不能同时在u的一颗子树内拼接
#include <bits/stdc++.h> using namespace std; #define mid (l+r>>1) typedef long long ll; const int maxn = 4e5+10; const int inf = 998244353; const int mod = 998244353; int n,a[maxn]; ll ans; vector<int>vec[maxn]; int siz[maxn],vis[maxn],mx[maxn],root,sumn; typedef pair<int,int>p; p res[maxn]; int top; ll he[maxn<<4]; int shu[maxn<<4],ls[maxn<<4],rs[maxn<<4],id,rot; void insert(int &rt,int l,int r,int v,int sz) { if( !rt ) rt = ++id; if( l>v || r<v ) return; if( l==r && l==v ) { shu[rt] += sz, he[rt] = he[rt]+sz*l; return; } insert( ls[rt],l,mid,v,sz ); insert( rs[rt],mid+1,r,v,sz ); shu[rt] = shu[ls[rt]]+shu[rs[rt]], he[rt] = he[ls[rt]]+he[rs[rt]]; } ll sum,su; void ask(int rt,int l,int r,int L,int R) { if( !rt ) return; if( l>R || r<L ) return; if( l>=L && r<=R ) { su += shu[rt], sum += he[rt]; return; } ask(ls[rt],l,mid,L,R); ask(rs[rt],mid+1,r,L,R); } void getroot(int u,int fa) { siz[u] = 1, mx[u] = 0; for( auto v:vec[u] ) { if( vis[v] || v==fa ) continue; getroot(v,u); siz[u] += siz[v]; mx[u] = max( mx[u],siz[v] ); } mx[u] = max( mx[u],sumn-siz[u] ); if( mx[u]<mx[root] ) root = u; } void dfs(int u,int fa,int mi,int mx) { res[++top] = p( mx,mi ); for(auto v:vec[u] ) { if( v==fa || vis[v] ) continue; dfs( v,u,min(mi,a[v]),max(mx,a[v]) ); } } ll calc(int u,int l,int r) { long long ans = 0; top = 0; dfs( u,u,l,r ); sort( res+1,res+1+top ); for(int i=1;i<=top;i++)//以最小值为下标 { //[1,i-1]的最小值有多少小于自己的 sum = 0, su = 0; ask(rot,0,inf,0,res[i].second-1); sum %= mod; ans = ( ans+1ll*res[i].first*sum%mod )%mod; ans = ( ans+1ll*res[i].first*res[i].second%mod*(i-su)%mod )%mod; insert( rot,0,inf,res[i].second,1 ); } for(int i=1;i<=top;i++) insert( rot,0,inf,res[i].second,-1 ); return ans; } void solve(int u) { vis[u] = 1; ans = ( ans+calc(u,a[u],a[u]) )%mod; for(auto v:vec[u] ) { if( vis[v] ) continue; ans = ( ans-calc(v,min(a[u],a[v]),max(a[u],a[v])) )%mod; sumn = siz[v], mx[root=0] = n+1; getroot(v,0); solve( root ); } } int main() { scanf("%d",&n ); for(int i=1;i<=n;i++) scanf("%d",&a[i] ); for(int i=1;i<n;i++) { int l,r; scanf("%d%d",&l,&r); vec[l].push_back( r ); vec[r].push_back( l ); } sumn = n, mx[root=0] = n+1; getroot(1,0); solve(1); printf("%lld",(ans%mod+mod)%mod ); }