题目描述
给你一棵根为1的有N个节点的树,以及Q次操作。
每次操作诸如:
1 x y:将节点x所在的子树的所有节点的权值加上y
2 x:询问x所在子树的所有节点的权值的平方和,答案模23333后输出
输入描述:
第一行两个整数N,Q
第二行N个整数,第i个表示节点i的初始权值
接下来N-1行每行两个整数u,v,表示u和v之间存在一条树边
接下来Q行每行一个操作,格式如题目描述
输出描述:
对于每个询问操作,输出一行一个整数,表示答案在模23333后的结果
示例1
输入
复制
5 5
0 0 0 0 0
1 2
1 3
3 4
3 5
1 1 3
1 3 7
1 4 5
1 5 6
2 1
输出
复制
599
备注:

  1. 数据范围
    一共有10个测试点,对于第i个测试点保证,N=10000 x i
    对于100 %100%的数据,有1 ≤ N,Q,y ≤ 105,1 ≤ x ≤ N

  2. 平方和的意思是:a2+b2+c2
    (a+b+c)^2
    是和的平方

比较裸的线段树维护区间平方和+dfs序。


AC代码:

#pragma GCC optimize(2)
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+10,mod=23333;
int n,q,st[N],ed[N],a[N],cnt,vis[N];
int head[N],nex[N<<1],to[N<<1],tot;
struct node{
	int l,r,sum,res,add;
}t[N<<2];
inline void add(int a,int b){
	to[++tot]=b; nex[tot]=head[a]; head[a]=tot;
}
void dfs(int x,int f){
	st[x]=++cnt;	vis[cnt]=x;
	for(int i=head[x];i;i=nex[i])	if(to[i]!=f)	dfs(to[i],x);
	ed[x]=cnt;
}
inline void push_up(int p){
	t[p].res=(t[p<<1].res+t[p<<1|1].res)%mod;
	t[p].sum=(t[p<<1].sum+t[p<<1|1].sum)%mod;
}
inline void push_down(int p){
	if(t[p].add){
		int v=t[p].add;	t[p].add=0;
		t[p<<1].add=(t[p<<1].add+v)%mod;
		t[p<<1|1].add=(t[p<<1|1].add+v)%mod;
		int ll=t[p<<1].r-t[p<<1].l+1;	int lr=t[p<<1|1].r-t[p<<1|1].l+1;
		t[p<<1].res=(t[p<<1].res+ll*v*v%mod+2*t[p<<1].sum*v%mod)%mod;
		t[p<<1|1].res=(t[p<<1|1].res+lr*v*v%mod+2*t[p<<1|1].sum*v%mod)%mod;
		t[p<<1].sum=(t[p<<1].sum+ll*v)%mod;
		t[p<<1|1].sum=(t[p<<1|1].sum+lr*v)%mod;
	}
}
void build(int p,int l,int r){
	t[p].l=l; t[p].r=r;
	if(l==r){
		t[p].sum=a[vis[l]]; t[p].res=t[p].sum*t[p].sum%mod; return ;
	}
	int mid=l+r>>1;	build(p<<1,l,mid); build(p<<1|1,mid+1,r);
	push_up(p);
}
void change(int p,int l,int r,int v){
	if(t[p].l==l&&t[p].r==r){
		t[p].res=(t[p].res+(r-l+1)*v*v%mod+2*t[p].sum*v%mod)%mod;
		t[p].sum=(t[p].sum+(r-l+1)*v%mod)%mod;	t[p].add=(t[p].add+v)%mod;
		return ;
	}
	push_down(p);	int mid=t[p].l+t[p].r>>1;
	if(r<=mid)	change(p<<1,l,r,v);
	else if(l>mid)	change(p<<1|1,l,r,v);
	else	change(p<<1,l,mid,v),change(p<<1|1,mid+1,r,v);
	push_up(p);
}
int ask(int p,int l,int r){
	if(t[p].l==l&&t[p].r==r)	return t[p].res%mod;
	push_down(p);	int mid=t[p].l+t[p].r>>1;
	if(r<=mid)	return ask(p<<1,l,r);
	else if(l>mid)	return ask(p<<1|1,l,r);
	else return (ask(p<<1,l,mid)+ask(p<<1|1,mid+1,r))%mod;
}
signed main(){
	scanf("%lld %lld",&n,&q);
	for(int i=1;i<=n;i++)	scanf("%lld",&a[i]);
	for(int i=1,a,b;i<n;i++)	scanf("%lld %lld",&a,&b),add(a,b),add(b,a);
	dfs(1,0);	build(1,1,n);
	while(q--){
		int op,x,y;	scanf("%lld %lld",&op,&x);
		if(op==1)	scanf("%lld",&y),change(1,st[x],ed[x],y);
		else	printf("%lld\n",ask(1,st[x],ed[x]));
	}
	return 0;
}