题目描述

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

输入输出格式

输入格式:
第一行包含4个正整数N、M、R、P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。

接下来一行包含N个非负整数,分别依次表示各个节点上初始的数值。

接下来N-1行每行包含两个整数x、y,表示点x和点y之间连有一条边(保证无环且连通)

接下来M行每行包含若干个正整数,每行表示一个操作,格式如下:

操作1: 1 x y z

操作2: 2 x y

操作3: 3 x z

操作4: 4 x

输出格式:
输出包含若干行,分别依次表示每个操作2或操作4所得的结果(对P取模)

输入输出样例

输入样例#1: 复制
5 5 2 24
7 3 7 8 0
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3
输出样例#1: 复制
2
21

非AC代码,用于表明更一般的树链剖分模板

#include "bits/stdc++.h"
#define pb push_back
#define ls l,m,now<<1
#define rs m+1,r,now<<1|1
using namespace std;
typedef long long ll;
typedef pair<int,int> pr;
inline int read() {
    int x=0;
    char c=getchar();
    while(c<'0'||c>'9') c=getchar();
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-'0', c=getchar();
    return x;
}

const int maxn = 5e5+10;
const int mod = 1e9+7;
const double eps = 1e-9;

int head[maxn], to[maxn*2], nxt[maxn*2], tot;

int N, M, R;
int val[maxn];
int d[maxn], f[maxn], s[maxn], son[maxn];
int top[maxn], id[maxn], rk[maxn], ID;

inline void add_edge(int u, int v) {
    ++tot, to[tot]=v, nxt[tot]=head[u], head[u]=tot;
    ++tot, to[tot]=u, nxt[tot]=head[v], head[v]=tot;
}

void dfs1(int u, int p) {
    f[u]=p;
    s[u]=1;
    ll ss=0;
    for(int i=head[u]; i; i=nxt[i]) {
        if(to[i]!=p) {
            d[to[i]]=d[u]+1;
            dfs1(to[i],u);
            s[u]+=s[to[i]];
            if(s[to[i]]>ss) ss=s[to[i]], son[u]=to[i];
        }
    }
}

void dfs2(int u, int p, int t) {
    id[u]=++ID;
    rk[ID]=u;
    top[u]=t;
    if(son[u]) dfs2(son[u],u,t);
    for(int i=head[u]; i; i=nxt[i]) {
        if(to[i]!=son[u]&&to[i]!=p) dfs2(to[i],u,to[i]);
    }
}

int node[maxn<<2], lazy[maxn<<2];

void build(int l, int r, int now) {
    if(l==r) {
        node[now]=val[rk[l]];
        return;
    }
    int m=(l+r)/2;
    build(ls);
    build(rs);
    node[now]=(node[now<<1]+node[now<<1|1]);
}

void push_down(int ln, int rn, int now) {
    if(lazy[now]) {
        int p=lazy[now];
        lazy[now]=0;
        lazy[now<<1]+=p;
        lazy[now<<1|1]+=p;
        node[now<<1]+=p*ln;
        node[now<<1|1]+=p*rn;
    }
}

void update(int L, int R, int d, int l, int r, int now) {
    if(L<=l&&r<=R) {
        node[now]+=(r-l+1)*d;
        lazy[now]+=d;
        return;
    }
    int m=(l+r)/2;
    push_down(m-l+1,r-m,now);
    if(L<=m) update(L,R,d,ls);
    if(R>m) update(L,R,d,rs);
    node[now]=node[now<<1]+node[now<<1|1];
}

int query(int L, int R, int l, int r, int now) {
    if(L<=l&&r<=R) {
        return node[now];
    }
    int m=(l+r)/2;
    push_down(m-l+1,r-m,now);
    int ans=0;
    if(L<=m) ans+=query(L,R,ls);
    if(R>m) ans+=query(L,R,rs);
    return ans;
}

void update1(int x, int y, int dd) {
    int tx=top[x], ty=top[y];
    while(tx!=ty) {
        if(d[tx]>=d[ty]) update(id[tx],id[x],dd,1,N,1), x=f[tx], tx=top[x];
        else update(id[ty],id[y],dd,1,N,1), y=f[ty], ty=top[y];
    }
    if(d[x]>=d[y]) update(id[y],id[x],dd,1,N,1);
    else update(id[x],id[y],dd,1,N,1);
}

int sum(int x, int y) {
    int ans=0;
    int tx=top[x], ty=top[y];
    while(tx!=ty) {
        if(d[tx]>=d[ty]) ans+=query(id[tx],id[x],1,N,1), x=f[tx], tx=top[x];
        else ans+=query(id[ty],id[y],1,N,1), y=f[ty], ty=top[y];
    }
    if(d[x]>=d[y]) ans+=query(id[y],id[x],1,N,1);
    else ans+=query(id[x],id[y],1,N,1);
    return ans;
}

int main() {
    //ios::sync_with_stdio(false);
    N=read(), M=read(), R=read();
    for(int i=1; i<=N; ++i) val[i]=read();
    for(int i=1; i<N; ++i) add_edge(read(),read());
    dfs1(R,0);
    dfs2(R,0,R);
    //for(int i=1; i<=N; ++i) printf("%d:%d %d %d %d\n", i, f[i], id[i], s[i], d[i]);
    build(1,N,1);
    while(M--) {
        int q=read();
        if(q==1) update1(read(),read(),read());
        else if(q==2) printf("%d\n", sum(read(),read()));
        else if(q==3) {
            int x=read(), z=read();
            update(id[x],id[x]+s[x]-1,z,1,N,1);
        }
        else if(q==4) {
            int x=read();
            printf("%d\n", query(id[x],id[x]+s[x]-1,1,N,1));
        }
    }
}