E题题解
题目大意:
题目挺短好理解,这里就不解释了
前置知识:
树形dp、线段树合并、树上启发式合并
解题思路:
对于任意的 ,
能到达
的状态数量为,
,
到达
也是这个值,因此总值为
。
题目转化为
这种树上求值问题,首先就要往树形dp上想。
不妨让 节点为整棵树的根
首先来看到根节点为 的子树
,思考如何求出
能到达
的状态数量,其中
。
对于每个 的儿子
(子树中直接相连的点),根节点为
的子树我们已经统计好了
只需要求 到
的值,将其拆成两部分来看
先来看第一部分,考虑把上面的式子绝对值拆掉,如果我们能知道
,那么总值为两部分:
一部分是 ,总值为
,一部分是
,
,考虑维护两个值:
使用权值线段树维护,其中 值为下标(离散也行,动态开点也行,这里使用动态开点)
struct XD_tree{//ll 是 long long ,用了define
ll pl,pr,num,sum,lz=1;//pl,pr是左右子树的下标,lz是懒标记
}tr[MXN*LG];//MXN=200002,LG=20
ll rt[MXN],trt;//rt是每个子树的根,trt是用于开点的
ll req[MXN*LG],ret;//记录回收的点
inline ll renew(){ll p=ret?req[ret--]:++trt;tr[p]=tr[0];return p;}//获取新点,使用数组的记得赋初值
inline void reuse(ll &p){if(p) req[++ret]=p;p=0;}//回收无用的点
上传操作:
inline void ptup(ll p){
XD_tree pl=tr[tr[p].pl],pr=tr[tr[p].pr];
tr[p].num=(pl.num+pr.num)%MO;//MO是模数998244353
tr[p].sum=(pl.sum+pr.sum)%MO;
}
区间查询:
inline void addp(par &x,par y){//par的加法运算
x={(x.st+y.st)%MO,(x.nd+y.nd)%MO};//st是first,nd是second
}
par ask(ll p,ll l,ll r,ll x,ll y){//par是pair<long long,long long>,用了define
if(!p||x>y) return {0,0};
if(x<=l&&r<=y) return {tr[p].num,tr[p].sum};
ll mid=l+r>>1;
par ans={0,0};
ptdn(p);
if(x<=mid) addp(ans,ask(tr[p].pl,l,mid,x,y));
if(y>mid) addp(ans,ask(tr[p].pr,mid+1,r,x,y));
return ans;
}
那么总值为
我们从叶子节点往根节点统计,那么从儿子节点 往
传递时,每个节点的
,都要加
才能变为
,故所有
值除以
,
值也除以
,引入区间乘操作(乘
的逆元),要使用懒标记,
chg1(rt[u],1,m,1,m,MO+1>>1)
( 是
的最大值,可替换为
)
inline void doit(ll p,ll k){//单点乘上k
tr[p].num=tr[p].num*k%MO;
tr[p].sum=tr[p].sum*k%MO;
tr[p].lz=tr[p].lz*k%MO;
}
inline void ptdn(ll p){//下传操作
if(tr[p].pl) doit(tr[p].pl,tr[p].lz);
if(tr[p].pr) doit(tr[p].pr,tr[p].lz);
tr[p].lz=1;
}
void chg1(ll &p,ll l,ll r,ll x,ll y,ll k){
if(!p||x>y) return;
if(x<=l&&r<=y){doit(p,k);return;}
ll mid=l+r>>1;
ptdn(p);
if(x<=mid) chg1(tr[p].pl,l,mid,x,y,k);
if(y>mid) chg1(tr[p].pr,mid+1,r,x,y,k);
ptup(p);
}
不仅如此,儿子节点 也应统计进去,状态数为
,
chg2(rt[u],1,m,a[u],f2[n-1])
( 代表
)
void chg2(ll &p,ll l,ll r,ll x,ll k){
if(!p) p=renew();
if(l==r){
tr[p].num=(tr[p].num+k)%MO;
tr[p].sum=(tr[p].sum+x*k)%MO;
return;
}
ll mid=l+r>>1;
ptdn(p);
if(x<=mid) chg2(tr[p].pl,l,mid,x,k);
else chg2(tr[p].pr,mid+1,r,x,k);
ptup(p);
}
可是节点 有很多儿子啊,这时候就要用到线段树合并了,将所有儿子合并起来,
mge(rt[u],rt[v],1,m)
(这里的 对应讲解里的
)
既然上面的第一部分
已经完成,接下来计算第二部分 } * |a_x-a_v|,u∈sub(x)-sub(y_i),u≠x&preview=true)
我有一个大胆的想法,每次合并一个儿子子树前,先枚举儿子子树 每个点,求到已经合并的儿子子树的值。
当然直接暴力肯定不行,如何优化?考虑树上启发式合并,先将最重的儿子子树合并,这些点将不用枚举,再枚举其他儿子子树。可以证明复杂度是 级别的。
最终的树形dp:
inline void mak(ll rt,ll w,ll k){//统计数量,rt是某个节点的根,w是某个a[u],k的解释在dfs2
par val=ask(rt,1,m,1,w),sal=ask(rt,1,m,1,m);//val是{num(1~w),sum(1~w)},sal是{num(1~m),sum(1~m)}
ans+=((2*val.st-sal.st+MO)*w%MO+sal.nd-2*val.nd+2*MO)%MO*k;
ans%=MO;
}
void dfs2(ll u,ll fa,ll rot){
mak(rt[rot],a[u],v2[dep[u]-dep[rot]+1]);//v2是 2^(u到rot的距离+1) 的逆元,这里加1是因为dfs里还没有chg1(rt[u],1,m,1,m,MO+1>>1)将全部值除以2
for(ll i=hd[u];i;i=e[i].nt){
ll v=e[i].v;
if(v==fa) continue;
dfs2(v,u,rot);
}
}
void dfs(ll u,ll fa){//树形dp搜索
sz[u]=1,dep[u]=dep[fa]+1;//sz是子树大小,dep是节点深度,用于求部分距离(在dfs2中使用)
ll mx=0;
for(ll i=hd[u];i;i=e[i].nt){
ll v=e[i].v;
if(v==fa) continue;
dfs(v,u);
sz[u]+=sz[v];
if(sz[v]>mx) mx=sz[v],son[u]=v;//找出重儿子
}
mge(rt[u],rt[son[u]],1,m);//先将重儿子合并
for(ll i=hd[u];i;i=e[i].nt){
ll v=e[i].v;
if(v==fa||v==son[u]) continue;
dfs2(v,u,u);//枚举轻儿子每个点
mge(rt[u],rt[v],1,m);//再将轻儿子合并
}
chg1(rt[u],1,m,1,m,MO+1>>1);//将所有值除以2
mak(rt[u],a[u],1);//统计u节点到其子树所有点的值
chg2(rt[u],1,m,a[u],f2[n]);//这里先将u节点统计进去了,因为父亲节会有全部除以2,故这里的状态数不是2^(n-1),而是2^n
}
最终代码
#include<bits/stdc++.h>
#define ll long long
#define par pair<ll,ll>
#define st first
#define nd second
#define MO 998244353ll
#define MXN 200002
#define LG 20
using namespace std;
inline void rd(ll &x){x=0;short f=1;char c=getchar();while((c<'0'||c>'9')&&c!='-') c=getchar();if(c=='-') f=-1,c=getchar();while(c>='0'&&c<='9') x=x*10+c-'0',c=getchar();x*=f;}
inline void pt(ll x){if(x<0) putchar('-'),x=-x;if(x>9) pt(x/10);putchar(x%10+'0');}
struct XD_tree{
ll pl,pr,num,sum,lz=1;
}tr[MXN*LG];
ll rt[MXN],trt;
ll req[MXN*LG],ret;
struct bzm{
ll v,nt;
}e[MXN*2];
ll hd[MXN],ett;
ll f2[MXN],v2[MXN];
ll sz[MXN],dep[MXN],son[MXN];
ll T=1,n,m,a[MXN],ans;
inline ll renew(){ll p=ret?req[ret--]:++trt;tr[p]=tr[0];return p;}
inline void reuse(ll &p){if(p) req[++ret]=p;p=0;}
inline void adde(ll u,ll v){
e[++ett]={v,hd[u]};
hd[u]=ett;
}
inline void addp(par &x,par y){
x={(x.st+y.st)%MO,(x.nd+y.nd)%MO};
}
inline void ptup(ll p){
XD_tree pl=tr[tr[p].pl],pr=tr[tr[p].pr];
tr[p].num=(pl.num+pr.num)%MO;
tr[p].sum=(pl.sum+pr.sum)%MO;
}
inline void doit(ll p,ll k){
tr[p].num=tr[p].num*k%MO;
tr[p].sum=tr[p].sum*k%MO;
tr[p].lz=tr[p].lz*k%MO;
}
inline void ptdn(ll p){
if(tr[p].pl) doit(tr[p].pl,tr[p].lz);
if(tr[p].pr) doit(tr[p].pr,tr[p].lz);
tr[p].lz=1;
}
void chg1(ll &p,ll l,ll r,ll x,ll y,ll k){
if(!p||x>y) return;
if(x<=l&&r<=y){doit(p,k);return;}
ll mid=l+r>>1;
ptdn(p);
if(x<=mid) chg1(tr[p].pl,l,mid,x,y,k);
if(y>mid) chg1(tr[p].pr,mid+1,r,x,y,k);
ptup(p);
}
void chg2(ll &p,ll l,ll r,ll x,ll k){
if(!p) p=renew();
if(l==r){
tr[p].num=(tr[p].num+k)%MO;
tr[p].sum=(tr[p].sum+x*k)%MO;
return;
}
ll mid=l+r>>1;
ptdn(p);
if(x<=mid) chg2(tr[p].pl,l,mid,x,k);
else chg2(tr[p].pr,mid+1,r,x,k);
ptup(p);
}
void mge(ll &p,ll &q,ll l,ll r){
if(!p||!q){p=p+q;return;}
tr[p].num=(tr[p].num+tr[q].num)%MO;
tr[p].sum=(tr[p].sum+tr[q].sum)%MO;
ptdn(p),ptdn(q);
ll mid=l+r>>1;
mge(tr[p].pl,tr[q].pl,l,mid);
mge(tr[p].pr,tr[q].pr,mid+1,r);
reuse(q);
}
par ask(ll p,ll l,ll r,ll x,ll y){
if(!p||x>y) return {0,0};
if(x<=l&&r<=y) return {tr[p].num,tr[p].sum};
ll mid=l+r>>1;
par ans={0,0};
ptdn(p);
if(x<=mid) addp(ans,ask(tr[p].pl,l,mid,x,y));
if(y>mid) addp(ans,ask(tr[p].pr,mid+1,r,x,y));
return ans;
}
inline void mak(ll rt,ll w,ll k){
par val=ask(rt,1,m,1,w),sal=ask(rt,1,m,1,m);
ans+=((2*val.st-sal.st+MO)*w%MO+sal.nd-2*val.nd+2*MO)%MO*k;
ans%=MO;
}
void dfs2(ll u,ll fa,ll rot){
mak(rt[rot],a[u],v2[dep[u]-dep[rot]+1]);
for(ll i=hd[u];i;i=e[i].nt){
ll v=e[i].v;
if(v==fa) continue;
dfs2(v,u,rot);
}
}
void dfs(ll u,ll fa){
sz[u]=1,dep[u]=dep[fa]+1;
ll mx=0;
for(ll i=hd[u];i;i=e[i].nt){
ll v=e[i].v;
if(v==fa) continue;
dfs(v,u);
sz[u]+=sz[v];
if(sz[v]>mx) mx=sz[v],son[u]=v;
}
mge(rt[u],rt[son[u]],1,m);
for(ll i=hd[u];i;i=e[i].nt){
ll v=e[i].v;
if(v==fa||v==son[u]) continue;
dfs2(v,u,u);
mge(rt[u],rt[v],1,m);
}
chg1(rt[u],1,m,1,m,MO+1>>1);
mak(rt[u],a[u],1);
chg2(rt[u],1,m,a[u],f2[n]);
}
void solve(){
rd(n);
v2[0]=f2[0]=1;
for(ll i=1;i<=n;i++)
f2[i]=(f2[i-1]<<1)%MO,
v2[i]=(v2[i-1]*(MO+1>>1))%MO;
for(ll i=1,u,v;i<n;i++)
rd(u),rd(v),adde(u,v),adde(v,u);
for(ll i=1;i<=n;i++)
rd(a[i]),m=max(m,a[i]);
dfs(1,0);
pt(ans);
}int main(){while(T--) solve();}