E题题解

题目大意:

题目挺短好理解,这里就不解释了

前置知识:

树形dp、线段树合并、树上启发式合并

解题思路:

对于任意的 能到达 的状态数量为, 到达 也是这个值,因此总值为

题目转化为

这种树上求值问题,首先就要往树形dp上想。

不妨让 节点为整棵树的根

首先来看到根节点为 的子树 ,思考如何求出 能到达 的状态数量,其中

对于每个 的儿子 (子树中直接相连的点),根节点为 的子树我们已经统计好了

只需要求 的值,将其拆成两部分来看

先来看第一部分,考虑把上面的式子绝对值拆掉,如果我们能知道 ,那么总值为两部分:

一部分是 ,总值为 ,一部分是 ,考虑维护两个值:

使用权值线段树维护,其中 值为下标(离散也行,动态开点也行,这里使用动态开点)

struct XD_tree{//ll 是 long long ,用了define
    ll pl,pr,num,sum,lz=1;//pl,pr是左右子树的下标,lz是懒标记
}tr[MXN*LG];//MXN=200002,LG=20
ll rt[MXN],trt;//rt是每个子树的根,trt是用于开点的
ll req[MXN*LG],ret;//记录回收的点
inline ll renew(){ll p=ret?req[ret--]:++trt;tr[p]=tr[0];return p;}//获取新点,使用数组的记得赋初值
inline void reuse(ll &p){if(p) req[++ret]=p;p=0;}//回收无用的点

上传操作:

inline void ptup(ll p){
    XD_tree pl=tr[tr[p].pl],pr=tr[tr[p].pr];
    tr[p].num=(pl.num+pr.num)%MO;//MO是模数998244353
    tr[p].sum=(pl.sum+pr.sum)%MO;
}

区间查询:

inline void addp(par &x,par y){//par的加法运算
    x={(x.st+y.st)%MO,(x.nd+y.nd)%MO};//st是first,nd是second
}
par ask(ll p,ll l,ll r,ll x,ll y){//par是pair<long long,long long>,用了define
    if(!p||x>y) return {0,0};
    if(x<=l&&r<=y) return {tr[p].num,tr[p].sum};
    ll mid=l+r>>1;
    par ans={0,0};
    ptdn(p);
    if(x<=mid) addp(ans,ask(tr[p].pl,l,mid,x,y));
    if(y>mid) addp(ans,ask(tr[p].pr,mid+1,r,x,y));
    return ans;
}

那么总值为

我们从叶子节点往根节点统计,那么从儿子节点 传递时,每个节点的 ,都要加 才能变为 ,故所有 值除以 值也除以 ,引入区间乘操作(乘 的逆元),要使用懒标记,chg1(rt[u],1,m,1,m,MO+1>>1) 的最大值,可替换为

inline void doit(ll p,ll k){//单点乘上k
    tr[p].num=tr[p].num*k%MO;
    tr[p].sum=tr[p].sum*k%MO;
    tr[p].lz=tr[p].lz*k%MO;
}
inline void ptdn(ll p){//下传操作
    if(tr[p].pl) doit(tr[p].pl,tr[p].lz);
    if(tr[p].pr) doit(tr[p].pr,tr[p].lz);
    tr[p].lz=1;
}
void chg1(ll &p,ll l,ll r,ll x,ll y,ll k){
    if(!p||x>y) return;
    if(x<=l&&r<=y){doit(p,k);return;}
    ll mid=l+r>>1;
    ptdn(p);
    if(x<=mid) chg1(tr[p].pl,l,mid,x,y,k);
    if(y>mid) chg1(tr[p].pr,mid+1,r,x,y,k);
    ptup(p);
}

不仅如此,儿子节点 也应统计进去,状态数为 chg2(rt[u],1,m,a[u],f2[n-1]) 代表

void chg2(ll &p,ll l,ll r,ll x,ll k){
    if(!p) p=renew();
    if(l==r){
        tr[p].num=(tr[p].num+k)%MO;
        tr[p].sum=(tr[p].sum+x*k)%MO;
        return;
    }
    ll mid=l+r>>1;
    ptdn(p);
    if(x<=mid) chg2(tr[p].pl,l,mid,x,k);
    else chg2(tr[p].pr,mid+1,r,x,k);
    ptup(p);
}

可是节点 有很多儿子啊,这时候就要用到线段树合并了,将所有儿子合并起来,mge(rt[u],rt[v],1,m)(这里的 对应讲解里的

既然上面的第一部分 已经完成,接下来计算第二部分

我有一个大胆的想法,每次合并一个儿子子树前,先枚举儿子子树 每个点,求到已经合并的儿子子树的值。

当然直接暴力肯定不行,如何优化?考虑树上启发式合并,先将最重的儿子子树合并,这些点将不用枚举,再枚举其他儿子子树。可以证明复杂度是 级别的。

最终的树形dp:

inline void mak(ll rt,ll w,ll k){//统计数量,rt是某个节点的根,w是某个a[u],k的解释在dfs2
    par val=ask(rt,1,m,1,w),sal=ask(rt,1,m,1,m);//val是{num(1~w),sum(1~w)},sal是{num(1~m),sum(1~m)}
    ans+=((2*val.st-sal.st+MO)*w%MO+sal.nd-2*val.nd+2*MO)%MO*k;
    ans%=MO;
}
void dfs2(ll u,ll fa,ll rot){
    mak(rt[rot],a[u],v2[dep[u]-dep[rot]+1]);//v2是 2^(u到rot的距离+1) 的逆元,这里加1是因为dfs里还没有chg1(rt[u],1,m,1,m,MO+1>>1)将全部值除以2
    for(ll i=hd[u];i;i=e[i].nt){
        ll v=e[i].v;
        if(v==fa) continue;
        dfs2(v,u,rot);
    }
}
void dfs(ll u,ll fa){//树形dp搜索
    sz[u]=1,dep[u]=dep[fa]+1;//sz是子树大小,dep是节点深度,用于求部分距离(在dfs2中使用)
    ll mx=0;
    for(ll i=hd[u];i;i=e[i].nt){
        ll v=e[i].v;
        if(v==fa) continue;
        dfs(v,u);
        sz[u]+=sz[v];
        if(sz[v]>mx) mx=sz[v],son[u]=v;//找出重儿子
    }
    mge(rt[u],rt[son[u]],1,m);//先将重儿子合并
    for(ll i=hd[u];i;i=e[i].nt){
        ll v=e[i].v;
        if(v==fa||v==son[u]) continue;
        dfs2(v,u,u);//枚举轻儿子每个点
        mge(rt[u],rt[v],1,m);//再将轻儿子合并
    }
    chg1(rt[u],1,m,1,m,MO+1>>1);//将所有值除以2
    mak(rt[u],a[u],1);//统计u节点到其子树所有点的值
    chg2(rt[u],1,m,a[u],f2[n]);//这里先将u节点统计进去了,因为父亲节会有全部除以2,故这里的状态数不是2^(n-1),而是2^n
}

最终代码

#include<bits/stdc++.h>
#define ll long long
#define par pair<ll,ll>
#define st first
#define nd second
#define MO 998244353ll
#define MXN 200002
#define LG 20
using namespace std;
inline void rd(ll &x){x=0;short f=1;char c=getchar();while((c<'0'||c>'9')&&c!='-') c=getchar();if(c=='-') f=-1,c=getchar();while(c>='0'&&c<='9') x=x*10+c-'0',c=getchar();x*=f;}
inline void pt(ll x){if(x<0) putchar('-'),x=-x;if(x>9) pt(x/10);putchar(x%10+'0');}
struct XD_tree{
    ll pl,pr,num,sum,lz=1;
}tr[MXN*LG];
ll rt[MXN],trt;
ll req[MXN*LG],ret;
struct bzm{
    ll v,nt;
}e[MXN*2];
ll hd[MXN],ett;
ll f2[MXN],v2[MXN];
ll sz[MXN],dep[MXN],son[MXN];
ll T=1,n,m,a[MXN],ans;
inline ll renew(){ll p=ret?req[ret--]:++trt;tr[p]=tr[0];return p;}
inline void reuse(ll &p){if(p) req[++ret]=p;p=0;}
inline void adde(ll u,ll v){
    e[++ett]={v,hd[u]};
    hd[u]=ett;
}
inline void addp(par &x,par y){
    x={(x.st+y.st)%MO,(x.nd+y.nd)%MO};
}
inline void ptup(ll p){
    XD_tree pl=tr[tr[p].pl],pr=tr[tr[p].pr];
    tr[p].num=(pl.num+pr.num)%MO;
    tr[p].sum=(pl.sum+pr.sum)%MO;
}
inline void doit(ll p,ll k){
    tr[p].num=tr[p].num*k%MO;
    tr[p].sum=tr[p].sum*k%MO;
    tr[p].lz=tr[p].lz*k%MO;
}
inline void ptdn(ll p){
    if(tr[p].pl) doit(tr[p].pl,tr[p].lz);
    if(tr[p].pr) doit(tr[p].pr,tr[p].lz);
    tr[p].lz=1;
}
void chg1(ll &p,ll l,ll r,ll x,ll y,ll k){
    if(!p||x>y) return;
    if(x<=l&&r<=y){doit(p,k);return;}
    ll mid=l+r>>1;
    ptdn(p);
    if(x<=mid) chg1(tr[p].pl,l,mid,x,y,k);
    if(y>mid) chg1(tr[p].pr,mid+1,r,x,y,k);
    ptup(p);
}
void chg2(ll &p,ll l,ll r,ll x,ll k){
    if(!p) p=renew();
    if(l==r){
        tr[p].num=(tr[p].num+k)%MO;
        tr[p].sum=(tr[p].sum+x*k)%MO;
        return;
    }
    ll mid=l+r>>1;
    ptdn(p);
    if(x<=mid) chg2(tr[p].pl,l,mid,x,k);
    else chg2(tr[p].pr,mid+1,r,x,k);
    ptup(p);
}
void mge(ll &p,ll &q,ll l,ll r){
    if(!p||!q){p=p+q;return;}
    tr[p].num=(tr[p].num+tr[q].num)%MO;
    tr[p].sum=(tr[p].sum+tr[q].sum)%MO;
    ptdn(p),ptdn(q);
    ll mid=l+r>>1;
    mge(tr[p].pl,tr[q].pl,l,mid);
    mge(tr[p].pr,tr[q].pr,mid+1,r);
    reuse(q);
}
par ask(ll p,ll l,ll r,ll x,ll y){
    if(!p||x>y) return {0,0};
    if(x<=l&&r<=y) return {tr[p].num,tr[p].sum};
    ll mid=l+r>>1;
    par ans={0,0};
    ptdn(p);
    if(x<=mid) addp(ans,ask(tr[p].pl,l,mid,x,y));
    if(y>mid) addp(ans,ask(tr[p].pr,mid+1,r,x,y));
    return ans;
}
inline void mak(ll rt,ll w,ll k){
    par val=ask(rt,1,m,1,w),sal=ask(rt,1,m,1,m);
    ans+=((2*val.st-sal.st+MO)*w%MO+sal.nd-2*val.nd+2*MO)%MO*k;
    ans%=MO;
}
void dfs2(ll u,ll fa,ll rot){
    mak(rt[rot],a[u],v2[dep[u]-dep[rot]+1]);
    for(ll i=hd[u];i;i=e[i].nt){
        ll v=e[i].v;
        if(v==fa) continue;
        dfs2(v,u,rot);
    }
}
void dfs(ll u,ll fa){
    sz[u]=1,dep[u]=dep[fa]+1;
    ll mx=0;
    for(ll i=hd[u];i;i=e[i].nt){
        ll v=e[i].v;
        if(v==fa) continue;
        dfs(v,u);
        sz[u]+=sz[v];
        if(sz[v]>mx) mx=sz[v],son[u]=v;
    }
    mge(rt[u],rt[son[u]],1,m);
    for(ll i=hd[u];i;i=e[i].nt){
        ll v=e[i].v;
        if(v==fa||v==son[u]) continue;
        dfs2(v,u,u);
        mge(rt[u],rt[v],1,m);
    }
    chg1(rt[u],1,m,1,m,MO+1>>1);
    mak(rt[u],a[u],1);
    chg2(rt[u],1,m,a[u],f2[n]);
}
void solve(){
    rd(n);
    v2[0]=f2[0]=1;
    for(ll i=1;i<=n;i++)
        f2[i]=(f2[i-1]<<1)%MO,
        v2[i]=(v2[i-1]*(MO+1>>1))%MO;
    for(ll i=1,u,v;i<n;i++)
        rd(u),rd(v),adde(u,v),adde(v,u);
    for(ll i=1;i<=n;i++)
        rd(a[i]),m=max(m,a[i]);
    dfs(1,0);
    pt(ans);
}int main(){while(T--) solve();}