题目描述
已知有 n 个节点,有 n−1 条边,形成一个树的结构。
给定一个根节点 k,每个节点都有一个权值,节点i的权值为vi。
给 m 个操作,操作有两种类型:
1 a x :表示将节点 aa 的权值加上 x
2 a :表示求 a 节点的子树上所有节点的和(包括 a 节点本身)
题解:
通过dfs序将一棵树当做一个数组处理
in[x]和out[x]为x在这个数组里的左右界限
x的子树也在区间[ in[x], out[x] ]里
所以查询a节点的子树上的所有节点的和,就是区间查询
而第一个操作就是单点修改,记得in[x]和out[x]都要修改
最后区间查询的答案要除以2,因为每个数算了两遍
可以用树状数组,也可以用线段树
最近在练码量,所以写线段树,但是段错误了。。。也不知道怎么回事
求大佬帮忙看看
我换了一个dfs序的方式就过了。。人傻了
代码;
#include<bits/stdc++.h> typedef long long ll; using namespace std; inline int read(){ int s=0,w=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();} while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();//s=(s<<3)+(s<<1)+(ch^48); return s*w; } const int maxn=4e6+9; int v[maxn]; vector<int>edge[maxn]; int in[maxn]; int out[maxn]; int pos[maxn]; int cnt=0; void dfs(int root,int fa){ in[root]=++cnt; pos[cnt]=root; for(int i=0;i<edge[root].size();i++) { int v=edge[root][i]; if(v==fa)continue; dfs(v,root); } out[root]=++cnt; pos[cnt]=root; } int tree[maxn]; void pushup(int root) { tree[root]=tree[root*2+1]+tree[root*2]; } void bulid(int root,int l,int r) { if(l==r) { tree[root]=v[pos[l]]; return ; } int mid=l+r>>1; bulid(root*2,l,mid); bulid(root*2+1,mid+1,r); pushup(root); } void update(int id,int x,int root,int l,int r) { if(l==r) { tree[root]+=x; return ; } int mid=l+r>>1; if(id<=mid)update(id,x,root*2,l,mid); else update(id,x,root*2+1,mid+1,r); pushup(root); } ll query(int root,int l,int r,int L,int R) { if(L==l&&r==R) { return tree[root]; } //pushup(root); // ll ans=0; int mid=l+r>>1; // if(L<mid)ans+=query(root*2,l,mid,L,R); // if(mid<R)ans+=query(root*2+1,mid+1,r,L,R); // return ans; if(R<=mid)return query(root*2,l,mid,L,R); else if(L>mid)return query(root*2+1,mid+1,r,L,R); else return query(root*2,l,mid,L,mid)+query(root*2+1,mid+1,r,mid+1,R); } int main() { ios::sync_with_stdio(false); int n,m,root; cin>>n>>m>>root; for(int i=1;i<=n;i++) { cin>>v[i]; } for(int i=1;i<n;i++) { int u,v; cin>>u>>v; edge[u].push_back(v); edge[v].push_back(u); } dfs(root,0); bulid(1,1,2*n); for(int i=1;i<=m;i++)//操作 { int op; cin>>op; if(op==1)//单点修改 { int p,x; cin>>p>>x; update(in[p],x,1,1,2*n); update(out[p],x,1,1,2*n); } else if(op==2)//区间查询 { int x; cin>>x; ll sum=0; sum=query(1,1,2*n,in[x],out[x]); cout<<sum/2<<endl; } } return 0; }
正确代码
#include<bits/stdc++.h> //typedef long long ll; using namespace std; inline int read(){ int s=0,w=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();} while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();//s=(s<<3)+(s<<1)+(ch^48); return s*w; } const int maxn=1e6+9; int v[maxn]; vector<int>edge[maxn]; int in[maxn]; int out[maxn]; int pos[maxn]; int cnt=0; void dfs(int root,int fa){ in[root]=++cnt; pos[cnt]=root; for(int i=0;i<edge[root].size();i++) { int v=edge[root][i]; if(v==fa)continue; dfs(v,root); } out[root]=cnt; //out[root]=++cnt; //pos[cnt]=root; } int tree[maxn<<2]; void pushup(int root) { tree[root]=tree[root*2+1]+tree[root*2]; } void bulid(int root,int l,int r) { if(l==r) { tree[root]=v[pos[l]]; return ; } int mid=l+r>>1; bulid(root*2,l,mid); bulid(root*2+1,mid+1,r); pushup(root); } void update(int id,int x,int root,int l,int r) { if(l==r) { tree[root]+=x; return ; } int mid=l+r>>1; if(id<=mid)update(id,x,root*2,l,mid); else update(id,x,root*2+1,mid+1,r); pushup(root); } int query(int root,int l,int r,int L,int R) { if(L==l&&r==R) { return tree[root]; } //pushup(root); // ll ans=0; int mid=l+r>>1; // if(L<mid)ans+=query(root*2,l,mid,L,R); // if(mid<R)ans+=query(root*2+1,mid+1,r,L,R); // return ans; if(R<=mid)return query(root*2,l,mid,L,R); else if(L>mid)return query(root*2+1,mid+1,r,L,R); else return query(root*2,l,mid,L,mid)+query(root*2+1,mid+1,r,mid+1,R); } int main() { //ios::sync_with_stdio(false); int n,m,root; //cin>>n>>m>>root; scanf("%d%d%d",&n,&m,&root); for(int i=1;i<=n;i++) { cin>>v[i]; } for(int i=1;i<n;i++) { int u,v; cin>>u>>v; edge[u].push_back(v); edge[v].push_back(u); } dfs(root,0); bulid(1,1,n); for(int i=1;i<=m;i++)//操作 { int op; scanf("%d",&op); //cin>>op; if(op==1)//单点修改 { int p,x; scanf("%d%d",&p,&x); //cin>>p>>x; update(in[p],x,1,1,n); // update(out[p],x,1,1,n); } else if(op==2)//区间查询 { int x; scanf("%d",&x); //cin>>x; int sum=0; sum=query(1,1,n,in[x],out[x]); cout<<sum<<endl; } } return 0; }