题目描述
给你一棵根为1的有N个节点的树,以及Q次操作。
每次操作诸如:
1 x y:将节点x所在的子树的所有节点的权值加上y
2 x:询问x所在子树的所有节点的权值的平方和,答案模23333后输出
输入描述:
第一行两个整数N,Q
第二行N个整数,第i个表示节点i的初始权值
接下来N-1行每行两个整数u,v,表示u和v之间存在一条树边
接下来Q行每行一个操作,格式如题目描述
输出描述:
对于每个询问操作,输出一行一个整数,表示答案在模23333后的结果
示例1
输入
复制
5 5
0 0 0 0 0
1 2
1 3
3 4
3 5
1 1 3
1 3 7
1 4 5
1 5 6
2 1
输出
复制
599
备注:
- 数据范围
一共有10个测试点,对于第i个测试点保证,N=10000 x i
对于100 %100%的数据,有1 ≤ N,Q,y ≤ 105,1 ≤ x ≤ N - 注
平方和的意思是:a2+b2+c2
(a+b+c)^2
是和的平方
比较裸的线段树维护区间平方和+dfs序。
AC代码:
#pragma GCC optimize(2)
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+10,mod=23333;
int n,q,st[N],ed[N],a[N],cnt,vis[N];
int head[N],nex[N<<1],to[N<<1],tot;
struct node{
int l,r,sum,res,add;
}t[N<<2];
inline void add(int a,int b){
to[++tot]=b; nex[tot]=head[a]; head[a]=tot;
}
void dfs(int x,int f){
st[x]=++cnt; vis[cnt]=x;
for(int i=head[x];i;i=nex[i]) if(to[i]!=f) dfs(to[i],x);
ed[x]=cnt;
}
inline void push_up(int p){
t[p].res=(t[p<<1].res+t[p<<1|1].res)%mod;
t[p].sum=(t[p<<1].sum+t[p<<1|1].sum)%mod;
}
inline void push_down(int p){
if(t[p].add){
int v=t[p].add; t[p].add=0;
t[p<<1].add=(t[p<<1].add+v)%mod;
t[p<<1|1].add=(t[p<<1|1].add+v)%mod;
int ll=t[p<<1].r-t[p<<1].l+1; int lr=t[p<<1|1].r-t[p<<1|1].l+1;
t[p<<1].res=(t[p<<1].res+ll*v*v%mod+2*t[p<<1].sum*v%mod)%mod;
t[p<<1|1].res=(t[p<<1|1].res+lr*v*v%mod+2*t[p<<1|1].sum*v%mod)%mod;
t[p<<1].sum=(t[p<<1].sum+ll*v)%mod;
t[p<<1|1].sum=(t[p<<1|1].sum+lr*v)%mod;
}
}
void build(int p,int l,int r){
t[p].l=l; t[p].r=r;
if(l==r){
t[p].sum=a[vis[l]]; t[p].res=t[p].sum*t[p].sum%mod; return ;
}
int mid=l+r>>1; build(p<<1,l,mid); build(p<<1|1,mid+1,r);
push_up(p);
}
void change(int p,int l,int r,int v){
if(t[p].l==l&&t[p].r==r){
t[p].res=(t[p].res+(r-l+1)*v*v%mod+2*t[p].sum*v%mod)%mod;
t[p].sum=(t[p].sum+(r-l+1)*v%mod)%mod; t[p].add=(t[p].add+v)%mod;
return ;
}
push_down(p); int mid=t[p].l+t[p].r>>1;
if(r<=mid) change(p<<1,l,r,v);
else if(l>mid) change(p<<1|1,l,r,v);
else change(p<<1,l,mid,v),change(p<<1|1,mid+1,r,v);
push_up(p);
}
int ask(int p,int l,int r){
if(t[p].l==l&&t[p].r==r) return t[p].res%mod;
push_down(p); int mid=t[p].l+t[p].r>>1;
if(r<=mid) return ask(p<<1,l,r);
else if(l>mid) return ask(p<<1|1,l,r);
else return (ask(p<<1,l,mid)+ask(p<<1|1,mid+1,r))%mod;
}
signed main(){
scanf("%lld %lld",&n,&q);
for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
for(int i=1,a,b;i<n;i++) scanf("%lld %lld",&a,&b),add(a,b),add(b,a);
dfs(1,0); build(1,1,n);
while(q--){
int op,x,y; scanf("%lld %lld",&op,&x);
if(op==1) scanf("%lld",&y),change(1,st[x],ed[x],y);
else printf("%lld\n",ask(1,st[x],ed[x]));
}
return 0;
}