题目描述
已知有 n 个节点,有 n−1 条边,形成一个树的结构。
给定一个根节点 k,每个节点都有一个权值,节点i的权值为vi。
给 m 个操作,操作有两种类型:
1 a x :表示将节点 aa 的权值加上 x
2 a :表示求 a 节点的子树上所有节点的和(包括 a 节点本身)
题解:
通过dfs序将一棵树当做一个数组处理
in[x]和out[x]为x在这个数组里的左右界限
x的子树也在区间[ in[x], out[x] ]里
所以查询a节点的子树上的所有节点的和,就是区间查询
而第一个操作就是单点修改,记得in[x]和out[x]都要修改
最后区间查询的答案要除以2,因为每个数算了两遍
可以用树状数组,也可以用线段树
最近在练码量,所以写线段树,但是段错误了。。。也不知道怎么回事
求大佬帮忙看看
我换了一个dfs序的方式就过了。。人傻了
代码;
#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
inline int read(){
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();//s=(s<<3)+(s<<1)+(ch^48);
return s*w;
}
const int maxn=4e6+9;
int v[maxn];
vector<int>edge[maxn];
int in[maxn];
int out[maxn];
int pos[maxn];
int cnt=0;
void dfs(int root,int fa){
in[root]=++cnt;
pos[cnt]=root;
for(int i=0;i<edge[root].size();i++)
{
int v=edge[root][i];
if(v==fa)continue;
dfs(v,root);
}
out[root]=++cnt;
pos[cnt]=root;
}
int tree[maxn];
void pushup(int root)
{
tree[root]=tree[root*2+1]+tree[root*2];
}
void bulid(int root,int l,int r)
{
if(l==r)
{
tree[root]=v[pos[l]];
return ;
}
int mid=l+r>>1;
bulid(root*2,l,mid);
bulid(root*2+1,mid+1,r);
pushup(root);
}
void update(int id,int x,int root,int l,int r)
{
if(l==r)
{
tree[root]+=x;
return ;
}
int mid=l+r>>1;
if(id<=mid)update(id,x,root*2,l,mid);
else update(id,x,root*2+1,mid+1,r);
pushup(root);
}
ll query(int root,int l,int r,int L,int R)
{
if(L==l&&r==R)
{
return tree[root];
}
//pushup(root);
// ll ans=0;
int mid=l+r>>1;
// if(L<mid)ans+=query(root*2,l,mid,L,R);
// if(mid<R)ans+=query(root*2+1,mid+1,r,L,R);
// return ans;
if(R<=mid)return query(root*2,l,mid,L,R);
else if(L>mid)return query(root*2+1,mid+1,r,L,R);
else return query(root*2,l,mid,L,mid)+query(root*2+1,mid+1,r,mid+1,R);
}
int main()
{
ios::sync_with_stdio(false);
int n,m,root;
cin>>n>>m>>root;
for(int i=1;i<=n;i++)
{
cin>>v[i];
}
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
edge[u].push_back(v);
edge[v].push_back(u);
}
dfs(root,0);
bulid(1,1,2*n);
for(int i=1;i<=m;i++)//操作
{
int op;
cin>>op;
if(op==1)//单点修改
{
int p,x;
cin>>p>>x;
update(in[p],x,1,1,2*n);
update(out[p],x,1,1,2*n);
}
else if(op==2)//区间查询
{
int x;
cin>>x;
ll sum=0;
sum=query(1,1,2*n,in[x],out[x]);
cout<<sum/2<<endl;
}
}
return 0;
}
正确代码
#include<bits/stdc++.h>
//typedef long long ll;
using namespace std;
inline int read(){
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();//s=(s<<3)+(s<<1)+(ch^48);
return s*w;
}
const int maxn=1e6+9;
int v[maxn];
vector<int>edge[maxn];
int in[maxn];
int out[maxn];
int pos[maxn];
int cnt=0;
void dfs(int root,int fa){
in[root]=++cnt;
pos[cnt]=root;
for(int i=0;i<edge[root].size();i++)
{
int v=edge[root][i];
if(v==fa)continue;
dfs(v,root);
}
out[root]=cnt;
//out[root]=++cnt;
//pos[cnt]=root;
}
int tree[maxn<<2];
void pushup(int root)
{
tree[root]=tree[root*2+1]+tree[root*2];
}
void bulid(int root,int l,int r)
{
if(l==r)
{
tree[root]=v[pos[l]];
return ;
}
int mid=l+r>>1;
bulid(root*2,l,mid);
bulid(root*2+1,mid+1,r);
pushup(root);
}
void update(int id,int x,int root,int l,int r)
{
if(l==r)
{
tree[root]+=x;
return ;
}
int mid=l+r>>1;
if(id<=mid)update(id,x,root*2,l,mid);
else update(id,x,root*2+1,mid+1,r);
pushup(root);
}
int query(int root,int l,int r,int L,int R)
{
if(L==l&&r==R)
{
return tree[root];
}
//pushup(root);
// ll ans=0;
int mid=l+r>>1;
// if(L<mid)ans+=query(root*2,l,mid,L,R);
// if(mid<R)ans+=query(root*2+1,mid+1,r,L,R);
// return ans;
if(R<=mid)return query(root*2,l,mid,L,R);
else if(L>mid)return query(root*2+1,mid+1,r,L,R);
else return query(root*2,l,mid,L,mid)+query(root*2+1,mid+1,r,mid+1,R);
}
int main()
{
//ios::sync_with_stdio(false);
int n,m,root;
//cin>>n>>m>>root;
scanf("%d%d%d",&n,&m,&root);
for(int i=1;i<=n;i++)
{
cin>>v[i];
}
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
edge[u].push_back(v);
edge[v].push_back(u);
}
dfs(root,0);
bulid(1,1,n);
for(int i=1;i<=m;i++)//操作
{
int op;
scanf("%d",&op);
//cin>>op;
if(op==1)//单点修改
{
int p,x;
scanf("%d%d",&p,&x);
//cin>>p>>x;
update(in[p],x,1,1,n);
// update(out[p],x,1,1,n);
}
else if(op==2)//区间查询
{
int x;
scanf("%d",&x);
//cin>>x;
int sum=0;
sum=query(1,1,n,in[x],out[x]);
cout<<sum<<endl;
}
}
return 0;
}

京公网安备 11010502036488号