前言
写在前面：线段树是一种用于处理区间查询与区间修改的数据结构，本篇博客用来记录我学习线段树时的刷题过程。
1.区间求和以及单点修改
由于这里只涉及单点修改操作，所以不需要使用 lazy 标记（懒标记）。
题目：HDU 1166（敌兵布阵）
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// NOTE: despite the name, INF is negative (-1e9). It is not referenced in this program.
const ll INF = -1e9;
const ll N = 5e4 + 7;   // capacity for a[]; the tree arrays use 4*N nodes
const ll mod = 1e9 + 7; // modulus used by getinv()

// Reads one signed integer from stdin via getchar().
// Characters before the first digit are skipped; a '-' among them makes the result negative.
// Returns 0 if input ends before any digit (the char-based version looped forever at EOF).
inline ll read() {
    ll s = 0, w = 1;
    int ch = getchar();  // int, not char, so that EOF can be recognised
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Writes x to stdout in decimal, without a trailing newline.
inline void write(ll x) {
    if (!x) { putchar('0'); return; }
    // Take the magnitude in unsigned arithmetic: negating LLONG_MIN as ll would overflow.
    unsigned long long tmp = x > 0 ? static_cast<unsigned long long>(x)
                                   : 0ULL - static_cast<unsigned long long>(x);
    if (x < 0) putchar('-');
    char F[24];  // 20 digits suffice for any 64-bit magnitude
    int cnt = 0;
    while (tmp > 0) { F[cnt++] = static_cast<char>('0' + tmp % 10); tmp /= 10; }
    while (cnt > 0) putchar(F[--cnt]);
}

// Greatest common divisor by Euclid's algorithm.
inline ll gcd(ll x, ll y) { return y ? gcd(y, x % y) : x; }

// a^b by binary exponentiation, no modulus. The caller must ensure the result fits in ll.
ll qpow(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b & 1) ans *= a;
        b >>= 1;
        if (b) a *= a;  // skip the final squaring, which could overflow even when the result fits
    }
    return ans;
}

// a^b mod `mod` by binary exponentiation; the result is in [0, mod) for mod > 0.
// The base is reduced first so every product stays below mod^2 (assumes mod^2 fits in ll).
ll qpow(ll a, ll b, ll mod) {
    a %= mod;
    if (a < 0) a += mod;  // negative bases map to their non-negative residue
    ll ans = 1 % mod;
    while (b) {
        if (b & 1) ans = ans * a % mod;
        b >>= 1;
        a = a * a % mod;
    }
    return ans;
}

// Value of the lowest set bit of x, e.g. lowbit(12) == 4.
inline int lowbit(int x) { return x & (-x); }

// Modular inverse of x via Fermat's little theorem: x^(mod-2) mod `mod`.
// Relies on mod (1e9+7) being prime and x not being a multiple of mod.
ll getinv(ll x) { return qpow(x, mod - 2, mod); }

// a[1..n]: input values. sum[rt]: sum of the segment owned by node rt
// (heap layout: the children of rt are 2*rt and 2*rt+1).
ll sum[N << 2], a[N];

// Builds the tree for a[l..r] rooted at node rt.
void build(int l, int r, int rt) {
    if (l == r) {  // leaf: copy the input value
        sum[rt] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Point update: adds j to element i; node rt covers [l, r].
// Every node on the root-to-leaf path contains i, so each just gets +j (no lazy tag needed).
void Add(int i, int j, int l, int r, int rt) {
    sum[rt] += j;
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (i <= mid) Add(i, j, l, mid, rt << 1);
    else          Add(i, j, mid + 1, r, rt << 1 | 1);
}

// Returns the sum of elements in [a, b]; node rt covers [l, r] (assumed to overlap [a, b]).
ll query(int a, int b, int l, int r, int rt) {
    if (a <= l && b >= r) return sum[rt];  // node lies entirely inside the query range
    int mid = (l + r) >> 1;
    ll ans = 0;
    if (a <= mid) ans += query(a, b, l, mid, rt << 1);
    if (b > mid)  ans += query(a, b, mid + 1, r, rt << 1 | 1);
    return ans;
}
// HDU 1166 driver: for each test case, reads n values, then processes
// "Add i j" / "Sub i j" / "Query i j" commands until "End".
int main(){
    int t, n, caseNo = 0;
    scanf("%d", &t);
    while (t--) {
        cout << "Case " << ++caseNo << ":\n";
        scanf("%d", &n);
        // a[] holds long long, so the conversion must be %lld (%ld is wrong where long is 32-bit).
        for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
        build(1, n, 1);
        string q;
        int i, j;  // command operands (distinct name from the case counter above)
        // Stop at "End", or if input runs out (the old loop spun forever on EOF).
        while (cin >> q && q != "End") {
            scanf("%d%d", &i, &j);
            if (q == "Add") {
                Add(i, j, 1, n, 1);
            } else if (q == "Query") {
                printf("%lld\n", query(i, j, 1, n, 1));
            } else if (q == "Sub") {
                Add(i, -j, 1, n, 1);
            }
        }
    }
    return 0;
}
2.区间修改
区间加法
1.将某区间每一个数加上k。
2.求出某区间每一个数的和。
注意：进入子区间之前，要先把当前结点的加法懒标记下推（pushdown）给子结点。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// NOTE: despite the name, INF is negative (-1e9). It is not referenced in this program.
const ll INF = -1e9;
const ll N = 1e5 + 7;    // capacity for a[]; the tree arrays use 4*N nodes
const ll mod = 1e9 + 7;  // modulus used by getinv()

// Reads one signed integer from stdin via getchar().
// Characters before the first digit are skipped; a '-' among them makes the result negative.
// Returns 0 if input ends before any digit (the char-based version looped forever at EOF).
inline ll read() {
    ll s = 0, w = 1;
    int ch = getchar();  // int, not char, so that EOF can be recognised
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Writes x to stdout in decimal, without a trailing newline.
inline void write(ll x) {
    if (!x) { putchar('0'); return; }
    // Take the magnitude in unsigned arithmetic: negating LLONG_MIN as ll would overflow.
    unsigned long long tmp = x > 0 ? static_cast<unsigned long long>(x)
                                   : 0ULL - static_cast<unsigned long long>(x);
    if (x < 0) putchar('-');
    char F[24];  // 20 digits suffice for any 64-bit magnitude
    int cnt = 0;
    while (tmp > 0) { F[cnt++] = static_cast<char>('0' + tmp % 10); tmp /= 10; }
    while (cnt > 0) putchar(F[--cnt]);
}

// Greatest common divisor by Euclid's algorithm.
inline ll gcd(ll x, ll y) { return y ? gcd(y, x % y) : x; }

// a^b by binary exponentiation, no modulus. The caller must ensure the result fits in ll.
ll qpow(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b & 1) ans *= a;
        b >>= 1;
        if (b) a *= a;  // skip the final squaring, which could overflow even when the result fits
    }
    return ans;
}

// a^b mod `mod` by binary exponentiation; the result is in [0, mod) for mod > 0.
// The base is reduced first so every product stays below mod^2 (assumes mod^2 fits in ll).
ll qpow(ll a, ll b, ll mod) {
    a %= mod;
    if (a < 0) a += mod;  // negative bases map to their non-negative residue
    ll ans = 1 % mod;
    while (b) {
        if (b & 1) ans = ans * a % mod;
        b >>= 1;
        a = a * a % mod;
    }
    return ans;
}

// Value of the lowest set bit of x, e.g. lowbit(12) == 4.
inline int lowbit(int x) { return x & (-x); }

// Modular inverse of x via Fermat's little theorem: x^(mod-2) mod `mod`.
// Relies on mod (1e9+7) being prime and x not being a multiple of mod.
ll getinv(ll x) { return qpow(x, mod - 2, mod); }

// a[1..n]: input values. For node rt (children 2*rt and 2*rt+1):
//   sum[rt] - sum of rt's segment, already including every increment applied to rt;
//   add[rt] - per-element increment counted in sum[rt] but not yet pushed to the children.
ll a[N], sum[N << 2], add[N << 2];

// Builds the tree for a[l..r] rooted at node rt.
void build(int l, int r, int rt) {
    if (l == r) {
        sum[rt] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Pushes rt's pending increment to its children. m is the length of rt's segment;
// with mid = (l+r)/2 the left child holds m - m/2 elements and the right child m/2.
void pushdown(int rt, int m) {
    if (add[rt]) {
        add[rt << 1] += add[rt];
        add[rt << 1 | 1] += add[rt];
        sum[rt << 1] += (m - (m >> 1)) * add[rt];
        sum[rt << 1 | 1] += (m >> 1) * add[rt];
        add[rt] = 0;  // the tag now lives in the children
    }
}

// Adds val to every element of [x, y]; node rt covers [l, r].
// val is ll (was int): widening is backward-compatible and keeps large increments exact.
void update(int l, int r, int rt, int x, int y, ll val) {
    if (x <= l && r <= y) {
        sum[rt] += static_cast<ll>(r - l + 1) * val;  // 64-bit product: (r-l+1)*val can exceed int
        add[rt] += val;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, r - l + 1);
    if (x <= mid) update(l, mid, rt << 1, x, y, val);
    if (y > mid) update(mid + 1, r, rt << 1 | 1, x, y, val);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Returns the sum of elements in [x, y]; node rt covers [l, r].
ll qsum(int rt, int l, int r, int x, int y) {
    if (x <= l && r <= y) {
        return sum[rt];
    }
    pushdown(rt, r - l + 1);
    ll ans = 0;
    int mid = (l + r) >> 1;
    if (x <= mid) ans += qsum(rt << 1, l, mid, x, y);
    if (y > mid) ans += qsum(rt << 1 | 1, mid + 1, r, x, y);
    return ans;
}
int main(){
int n,m,op,x,y,k;
cin>>n>>m;
for(int i=1;i<=n;i++) a[i]=read();
build(1,n,1);
for(int i=1;i<=m;i++){
cin>>op;
if(op==2){
cin>>x>>y;
cout<<qsum(1,1,n,x,y)<<endl;
continue;
}
if(op==1){
cin>>x>>y>>k;
update(1,n,1,x,y,k);
}
}
}
/*
Sample input fragment (incomplete: main() expects "n m", then n numbers, then m operations):
5
1 2 3 4 5
1 3
*/
区间乘法
P3373 【模板】线段树 2
1.将某区间每一个数乘上x
2.将某区间每一个数加上x
3.求出某区间每一个数的和
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// NOTE: despite the name, INF is negative (-1e9). It is not referenced in this program.
const ll INF = -1e9;
const ll N = 1e5 + 7;  // capacity for a[]; the tree arrays use 4*N nodes

// Reads one signed integer from stdin via getchar().
// Characters before the first digit are skipped; a '-' among them makes the result negative.
// Returns 0 if input ends before any digit (the char-based version looped forever at EOF).
inline ll read() {
    ll s = 0, w = 1;
    int ch = getchar();  // int, not char, so that EOF can be recognised
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Writes x to stdout in decimal, without a trailing newline.
inline void write(ll x) {
    if (!x) { putchar('0'); return; }
    // Take the magnitude in unsigned arithmetic: negating LLONG_MIN as ll would overflow.
    unsigned long long tmp = x > 0 ? static_cast<unsigned long long>(x)
                                   : 0ULL - static_cast<unsigned long long>(x);
    if (x < 0) putchar('-');
    char F[24];  // 20 digits suffice for any 64-bit magnitude
    int cnt = 0;
    while (tmp > 0) { F[cnt++] = static_cast<char>('0' + tmp % 10); tmp /= 10; }
    while (cnt > 0) putchar(F[--cnt]);
}

// a[1..n]: input values. For node rt (children 2*rt and 2*rt+1):
//   sum[rt]          - segment sum mod `mod`, already reflecting rt's own tags;
//   mul[rt], add[rt] - pending transform v -> v*mul[rt] + add[rt] for rt's children.
ll a[N], sum[N << 2], add[N << 2], mul[N << 2];
ll mod;  // modulus, read from input in main()

// Maps v to its residue in [0, mod); assumes mod > 0.
static inline ll normMod(ll v) {
    v %= mod;
    return v < 0 ? v + mod : v;
}

// Pushes rt's tags to its children; rt covers [l, r].
// Multiplication goes first because a child's pending add must be scaled too:
// (v*m + a)*k + d = v*(m*k) + (a*k + d).
void pushdown(int rt, int l, int r) {
    int ls = rt << 1, rs = rt << 1 | 1;
    if (mul[rt] != 1) {
        ll k = mul[rt];
        mul[ls] = mul[ls] * k % mod;
        mul[rs] = mul[rs] * k % mod;
        add[ls] = add[ls] * k % mod;
        add[rs] = add[rs] * k % mod;
        sum[ls] = sum[ls] * k % mod;
        sum[rs] = sum[rs] * k % mod;
        mul[rt] = 1;
    }
    if (add[rt]) {
        ll d = add[rt];
        int mid = (l + r) >> 1;
        sum[ls] = (sum[ls] + static_cast<ll>(mid - l + 1) * d) % mod;
        sum[rs] = (sum[rs] + static_cast<ll>(r - mid) * d) % mod;
        add[ls] = (add[ls] + d) % mod;
        add[rs] = (add[rs] + d) % mod;
        add[rt] = 0;  // the tag now lives in the children
    }
}

// Builds the tree for a[l..r] rooted at node rt and resets its tags.
void build(int l, int r, int rt) {
    mul[rt] = 1;
    add[rt] = 0;  // clear stale tags so the tree can be rebuilt
    if (l == r) {
        sum[rt] = normMod(a[l]);  // reduce so later products stay below mod^2
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % mod;
}

// Multiplies every element of [x, y] by z (mod `mod`); node rt covers [l, r].
void update1(int rt, int l, int r, int x, int y, ll z) {
    z = normMod(z);  // keep z < mod so products cannot overflow
    if (x <= l && r <= y) {
        sum[rt] = sum[rt] * z % mod;
        add[rt] = add[rt] * z % mod;
        mul[rt] = mul[rt] * z % mod;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, l, r);
    if (x <= mid) update1(rt << 1, l, mid, x, y, z);
    if (y > mid) update1(rt << 1 | 1, mid + 1, r, x, y, z);
    sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % mod;
}

// Adds val to every element of [x, y] (mod `mod`); node rt covers [l, r].
// val is ll (was int), so the caller's 64-bit k is no longer truncated.
void update2(int rt, int l, int r, int x, int y, ll val) {
    val = normMod(val);
    if (x <= l && r <= y) {
        // 64-bit product: (r-l+1)*val overflowed int when val was an int.
        sum[rt] = (sum[rt] + static_cast<ll>(r - l + 1) * val) % mod;
        add[rt] = (add[rt] + val) % mod;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, l, r);
    if (x <= mid) update2(rt << 1, l, mid, x, y, val);
    if (y > mid) update2(rt << 1 | 1, mid + 1, r, x, y, val);
    sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % mod;
}

// Returns the sum of elements in [x, y] mod `mod`; node rt covers [l, r].
ll qsum(int rt, int l, int r, int x, int y) {
    if (x <= l && r <= y) return sum[rt];
    pushdown(rt, l, r);
    ll ans = 0;
    int mid = (l + r) >> 1;
    if (x <= mid) ans += qsum(rt << 1, l, mid, x, y);
    ans %= mod;
    if (y > mid) ans += qsum(rt << 1 | 1, mid + 1, r, x, y);
    return ans % mod;
}
// Driver: reads n, m and the modulus, then n values, then m operations:
//   "1 x y k" multiplies [x, y] by k;  "2 x y k" adds k to [x, y];  "3 x y" prints the sum.
// read() uses getchar(), so cin must stay synchronised with stdio (no sync_with_stdio(false)).
int main(){
    ll n, m, op, x, y, k;
    cin >> n >> m >> mod;
    for (int i = 1; i <= n; i++) a[i] = read();
    build(1, n, 1);
    for (int i = 1; i <= m; i++) {
        op = read();
        if (op == 1) {
            cin >> x >> y >> k;
            update1(1, 1, n, x, y, k);
        } else if (op == 2) {
            cin >> x >> y >> k;
            update2(1, 1, n, x, y, k);
        } else {
            cin >> x >> y;
            cout << qsum(1, 1, n, x, y) << '\n';  // '\n' rather than endl: no flush per query
        }
    }
    return 0;
}
区间修改:赋值
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// NOTE: despite the name, INF is negative (-1e9). It is not referenced in this program.
const ll INF = -1e9;
const ll N = 1e5 + 7;    // capacity for a[]; the tree arrays use 4*N nodes
const ll mod = 1e9 + 7;  // modulus used by getinv()

// Reads one signed integer from stdin via getchar().
// Characters before the first digit are skipped; a '-' among them makes the result negative.
// Returns 0 if input ends before any digit (the char-based version looped forever at EOF).
inline ll read() {
    ll s = 0, w = 1;
    int ch = getchar();  // int, not char, so that EOF can be recognised
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Writes x to stdout in decimal, without a trailing newline.
inline void write(ll x) {
    if (!x) { putchar('0'); return; }
    // Take the magnitude in unsigned arithmetic: negating LLONG_MIN as ll would overflow.
    unsigned long long tmp = x > 0 ? static_cast<unsigned long long>(x)
                                   : 0ULL - static_cast<unsigned long long>(x);
    if (x < 0) putchar('-');
    char F[24];  // 20 digits suffice for any 64-bit magnitude
    int cnt = 0;
    while (tmp > 0) { F[cnt++] = static_cast<char>('0' + tmp % 10); tmp /= 10; }
    while (cnt > 0) putchar(F[--cnt]);
}

// Greatest common divisor by Euclid's algorithm.
inline ll gcd(ll x, ll y) { return y ? gcd(y, x % y) : x; }

// a^b by binary exponentiation, no modulus. The caller must ensure the result fits in ll.
ll qpow(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b & 1) ans *= a;
        b >>= 1;
        if (b) a *= a;  // skip the final squaring, which could overflow even when the result fits
    }
    return ans;
}

// a^b mod `mod` by binary exponentiation; the result is in [0, mod) for mod > 0.
// The base is reduced first so every product stays below mod^2 (assumes mod^2 fits in ll).
ll qpow(ll a, ll b, ll mod) {
    a %= mod;
    if (a < 0) a += mod;  // negative bases map to their non-negative residue
    ll ans = 1 % mod;
    while (b) {
        if (b & 1) ans = ans * a % mod;
        b >>= 1;
        a = a * a % mod;
    }
    return ans;
}

// Value of the lowest set bit of x, e.g. lowbit(12) == 4.
inline int lowbit(int x) { return x & (-x); }

// Modular inverse of x via Fermat's little theorem: x^(mod-2) mod `mod`.
// Relies on mod (1e9+7) being prime and x not being a multiple of mod.
ll getinv(ll x) { return qpow(x, mod - 2, mod); }

// a[1..n]: input values. For node rt (children 2*rt and 2*rt+1):
//   sum[rt]        - sum of rt's segment (already reflects any assignment to rt);
//   change[rt]     - value pending assignment to every element of rt's children;
//   has_change[rt] - whether change[rt] is a live tag. A separate flag is required
//                    because 0 is a legitimate value to assign (using 0 as "no tag"
//                    silently dropped range assignments of 0).
ll a[N], sum[N << 2], change[N << 2];
bool has_change[N << 2];

// Builds the tree for a[l..r] rooted at node rt and clears its tags.
void build(int l, int r, int rt) {
    change[rt] = 0;
    has_change[rt] = false;
    if (l == r) {
        sum[rt] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Pushes rt's pending assignment to its children. m is the length of rt's segment;
// with mid = (l+r)/2 the left child holds m - m/2 elements and the right child m/2.
void pushdown(int rt, int m) {
    if (!has_change[rt]) return;
    int ls = rt << 1, rs = rt << 1 | 1;
    ll v = change[rt];
    change[ls] = change[rs] = v;
    has_change[ls] = has_change[rs] = true;
    sum[ls] = (m - (m >> 1)) * v;
    sum[rs] = (m >> 1) * v;
    change[rt] = 0;
    has_change[rt] = false;  // the tag now lives in the children
}

// Assigns val to every element of [x, y]; node rt covers [l, r].
// val is ll (was int): widening is backward-compatible and avoids int overflow below.
void update(int l, int r, int rt, int x, int y, ll val) {
    if (x <= l && r <= y) {
        sum[rt] = static_cast<ll>(r - l + 1) * val;  // 64-bit product
        change[rt] = val;
        has_change[rt] = true;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, r - l + 1);
    if (x <= mid) update(l, mid, rt << 1, x, y, val);
    if (y > mid) update(mid + 1, r, rt << 1 | 1, x, y, val);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Returns the sum of elements in [x, y]; node rt covers [l, r].
ll qsum(int l, int r, int rt, int x, int y) {
    if (x <= l && r <= y) return sum[rt];
    pushdown(rt, r - l + 1);
    ll ans = 0;
    int mid = (l + r) >> 1;
    if (x <= mid) ans += qsum(l, mid, rt << 1, x, y);
    if (y > mid) ans += qsum(mid + 1, r, rt << 1 | 1, x, y);
    return ans;
}
int main(){
int n,m,op,x,y,k;
cin>>n;
for(int i=1;i<=n;i++) a[i]=read();
build(1,n,1);
cin>>m;
for(int i=1;i<=m;i++){
cin>>op;
if(op==0){
cin>>x>>y;
cout<<qsum(1,n,1,x,y)<<endl;
}
if(op==1){
cin>>x>>y>>k;
update(1,n,1,x,y,k);
}
}
}区间修改2:求平方和
数据结构
1 l r 询问区间[l,r]内的元素和
2 l r 询问区间[l,r]内的元素的平方和
3 l r x 将区间[l,r]内的每一个元素都乘上x
4 l r x 将区间[l,r]内的每一个元素都加上x
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// NOTE: despite the name, INF is negative (-1e9). It is not referenced in this program.
const ll INF = -1e9;
const ll N = 1e5 + 7;  // capacity for a[]; the tree arrays use 4*N nodes

// Reads one signed integer from stdin via getchar().
// Characters before the first digit are skipped; a '-' among them makes the result negative.
// Returns 0 if input ends before any digit (the char-based version looped forever at EOF).
inline ll read() {
    ll s = 0, w = 1;
    int ch = getchar();  // int, not char, so that EOF can be recognised
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Writes x to stdout in decimal, without a trailing newline.
inline void write(ll x) {
    if (!x) { putchar('0'); return; }
    // Take the magnitude in unsigned arithmetic: negating LLONG_MIN as ll would overflow.
    unsigned long long tmp = x > 0 ? static_cast<unsigned long long>(x)
                                   : 0ULL - static_cast<unsigned long long>(x);
    if (x < 0) putchar('-');
    char F[24];  // 20 digits suffice for any 64-bit magnitude
    int cnt = 0;
    while (tmp > 0) { F[cnt++] = static_cast<char>('0' + tmp % 10); tmp /= 10; }
    while (cnt > 0) putchar(F[--cnt]);
}

// a[1..n]: input values. For node rt (children 2*rt and 2*rt+1):
//   sum[rt]  - sum of its elements; sum1[rt] - sum of their squares
//              (both already reflect rt's own tags);
//   mul[rt], add[rt] - pending transform v -> v*mul[rt] + add[rt] for rt's children.
// No modulus is applied anywhere, so all values must stay within ll range.
ll a[N], sum[N << 2], add[N << 2], mul[N << 2], sum1[N << 2];

// Pushes rt's tags to its children (multiplication first, then addition); rt covers [l, r].
void pushdown(int rt, int l, int r) {
    int ls = rt << 1, rs = rt << 1 | 1;
    if (mul[rt] != 1) {
        ll k = mul[rt];
        // (v*m + a)*k = v*(m*k) + a*k: the child's pending add must be scaled as well.
        // Omitting this made later pushes add an unscaled increment.
        mul[ls] *= k;
        mul[rs] *= k;
        add[ls] *= k;
        add[rs] *= k;
        sum[ls] *= k;
        sum[rs] *= k;
        sum1[ls] *= k * k;  // sum of (k*v)^2 = k^2 * sum of v^2
        sum1[rs] *= k * k;
        mul[rt] = 1;
    }
    if (add[rt]) {
        ll d = add[rt];
        int mid = (l + r) >> 1;
        ll lenL = mid - l + 1, lenR = r - mid;
        // sum of (v+d)^2 = sum v^2 + 2d*sum v + len*d^2, using each child's sum before adding d.
        sum1[ls] += 2 * sum[ls] * d + lenL * d * d;
        sum1[rs] += 2 * sum[rs] * d + lenR * d * d;
        sum[ls] += lenL * d;
        sum[rs] += lenR * d;
        add[ls] += d;
        add[rs] += d;
        add[rt] = 0;  // the tag now lives in the children
    }
}

// Builds the tree for a[l..r] rooted at node rt and resets its tags.
void build(int l, int r, int rt) {
    add[rt] = 0;
    mul[rt] = 1;
    if (l == r) {
        sum[rt] = a[l];
        sum1[rt] = a[l] * a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    sum1[rt] = sum1[rt << 1] + sum1[rt << 1 | 1];
}

// Multiplies every element of [x, y] by z; node rt covers [l, r].
void update1(int rt, int l, int r, int x, int y, ll z) {
    if (x <= l && r <= y) {
        sum[rt] *= z;
        sum1[rt] *= z * z;
        mul[rt] *= z;
        add[rt] *= z;  // compose with the existing tag: (v*m + a)*z
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, l, r);
    if (x <= mid) update1(rt << 1, l, mid, x, y, z);
    if (y > mid) update1(rt << 1 | 1, mid + 1, r, x, y, z);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    sum1[rt] = sum1[rt << 1] + sum1[rt << 1 | 1];
}

// Adds val to every element of [x, y]; node rt covers [l, r].
void update2(int rt, int l, int r, int x, int y, ll val) {
    if (x <= l && r <= y) {
        ll oldSum = sum[rt];  // needed for the 2*val*sum term of the square update
        sum[rt] += (r - l + 1) * val;
        sum1[rt] += 2 * val * oldSum + (r - l + 1) * val * val;
        add[rt] += val;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, l, r);
    if (x <= mid) update2(rt << 1, l, mid, x, y, val);
    if (y > mid) update2(rt << 1 | 1, mid + 1, r, x, y, val);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    sum1[rt] = sum1[rt << 1] + sum1[rt << 1 | 1];
}

// Returns the sum (c == 1) or the sum of squares (any other c) over [x, y]; rt covers [l, r].
ll qsum(int rt, int l, int r, int x, int y, int c) {
    if (x <= l && r <= y) {
        if (c == 1)
            return sum[rt];
        else
            return sum1[rt];
    }
    pushdown(rt, l, r);
    ll ans = 0;
    int mid = (l + r) >> 1;
    if (x <= mid) ans += qsum(rt << 1, l, mid, x, y, c);
    if (y > mid) ans += qsum(rt << 1 | 1, mid + 1, r, x, y, c);
    return ans;
}
// Driver: reads n, m and n values, then m operations:
//   1 l r   -> print the sum of [l, r]
//   2 l r   -> print the sum of squares of [l, r]
//   3 l r x -> multiply [l, r] by x
//   4 l r x -> add x to [l, r]
// Unknown opcodes are skipped without reading further tokens.
int main(){
    ll n, m, op, x, y, k;
    cin >> n >> m;
    for (int idx = 1; idx <= n; ++idx) a[idx] = read();
    build(1, n, 1);
    for (int step = 1; step <= m; ++step) {
        op = read();
        switch (op) {
        case 3:
            scanf("%lld%lld%lld", &x, &y, &k);
            update1(1, 1, n, x, y, k);
            break;
        case 4:
            scanf("%lld%lld%lld", &x, &y, &k);
            update2(1, 1, n, x, y, k);
            break;
        case 1:
        case 2:
            cin >> x >> y;
            printf("%lld\n", qsum(1, 1, n, x, y, op));
            break;
        default:
            break;
        }
    }
}
代码2
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// NOTE: despite the name, INF is negative (-1e9). It is not referenced in this program.
const ll INF = -1e9;
const ll N = 1e5 + 7;  // capacity for a[]; the tree arrays use 4*N nodes

// Reads one signed integer from stdin via getchar().
// Characters before the first digit are skipped; a '-' among them makes the result negative.
// Returns 0 if input ends before any digit (the char-based version looped forever at EOF).
inline ll read() {
    ll s = 0, w = 1;
    int ch = getchar();  // int, not char, so that EOF can be recognised
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Writes x to stdout in decimal, without a trailing newline.
inline void write(ll x) {
    if (!x) { putchar('0'); return; }
    // Take the magnitude in unsigned arithmetic: negating LLONG_MIN as ll would overflow.
    unsigned long long tmp = x > 0 ? static_cast<unsigned long long>(x)
                                   : 0ULL - static_cast<unsigned long long>(x);
    if (x < 0) putchar('-');
    char F[24];  // 20 digits suffice for any 64-bit magnitude
    int cnt = 0;
    while (tmp > 0) { F[cnt++] = static_cast<char>('0' + tmp % 10); tmp /= 10; }
    while (cnt > 0) putchar(F[--cnt]);
}

// a[1..n]: input values. For node rt (children 2*rt and 2*rt+1):
//   sum[rt]  - sum of its elements; sum1[rt] - sum of their squares
//              (both already reflect rt's own tags);
//   mul[rt], add[rt] - pending transform v -> v*mul[rt] + add[rt] for rt's children.
// No modulus is applied anywhere, so all values must stay within ll range.
ll a[N], sum[N << 2], add[N << 2], mul[N << 2], sum1[N << 2];

// Pushes rt's tags to its children (multiplication first, then addition); rt covers [l, r].
void pushdown(int rt, int l, int r) {
    int ls = rt << 1, rs = rt << 1 | 1;
    if (mul[rt] != 1) {
        ll k = mul[rt];
        // (v*m + a)*k = v*(m*k) + a*k: the child's pending add is scaled as well.
        mul[ls] *= k;
        mul[rs] *= k;
        add[ls] *= k;
        add[rs] *= k;
        sum[ls] *= k;
        sum[rs] *= k;
        sum1[ls] *= k * k;  // sum of (k*v)^2 = k^2 * sum of v^2
        sum1[rs] *= k * k;
        mul[rt] = 1;
    }
    if (add[rt]) {
        ll d = add[rt];
        int mid = (l + r) >> 1;
        ll lenL = mid - l + 1, lenR = r - mid;
        // sum of (v+d)^2 = sum v^2 + 2d*sum v + len*d^2, using each child's sum before adding d.
        sum1[ls] += 2 * sum[ls] * d + lenL * d * d;
        sum1[rs] += 2 * sum[rs] * d + lenR * d * d;
        sum[ls] += lenL * d;
        sum[rs] += lenR * d;
        add[ls] += d;
        add[rs] += d;
        add[rt] = 0;  // the tag now lives in the children
    }
}

// Builds the tree for a[l..r] rooted at node rt and resets its tags.
void build(int l, int r, int rt) {
    add[rt] = 0;
    mul[rt] = 1;
    if (l == r) {
        sum[rt] = a[l];
        sum1[rt] = a[l] * a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    sum1[rt] = sum1[rt << 1] + sum1[rt << 1 | 1];
}

// Multiplies every element of [x, y] by z; node rt covers [l, r].
void update1(int rt, int l, int r, int x, int y, ll z) {
    if (x <= l && r <= y) {
        sum[rt] *= z;
        sum1[rt] *= z * z;
        add[rt] *= z;  // compose with the existing tag: (v*m + a)*z
        mul[rt] *= z;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, l, r);
    if (x <= mid) update1(rt << 1, l, mid, x, y, z);
    if (y > mid) update1(rt << 1 | 1, mid + 1, r, x, y, z);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    sum1[rt] = sum1[rt << 1] + sum1[rt << 1 | 1];
}

// Adds val to every element of [x, y]; node rt covers [l, r].
void update2(int rt, int l, int r, int x, int y, ll val) {
    if (x <= l && r <= y) {
        ll oldSum = sum[rt];  // needed for the 2*val*sum term of the square update
        sum[rt] += (r - l + 1) * val;
        sum1[rt] += 2 * val * oldSum + (r - l + 1) * val * val;
        add[rt] += val;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, l, r);
    if (x <= mid) update2(rt << 1, l, mid, x, y, val);
    if (y > mid) update2(rt << 1 | 1, mid + 1, r, x, y, val);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    sum1[rt] = sum1[rt << 1] + sum1[rt << 1 | 1];
}

// Returns the sum (c == 1) or the sum of squares (any other c) over [x, y]; rt covers [l, r].
ll qsum(int rt, int l, int r, int x, int y, int c) {
    if (x <= l && r <= y) {
        if (c == 1)
            return sum[rt];
        else
            return sum1[rt];
    }
    pushdown(rt, l, r);
    ll ans = 0;
    int mid = (l + r) >> 1;
    if (x <= mid) ans += qsum(rt << 1, l, mid, x, y, c);
    if (y > mid) ans += qsum(rt << 1 | 1, mid + 1, r, x, y, c);
    return ans;
}
// Driver: reads n, m and n values, then m operations:
//   1 l r   -> print the sum of [l, r]
//   2 l r   -> print the sum of squares of [l, r]
//   3 l r x -> multiply [l, r] by x
//   4 l r x -> add x to [l, r]
// Unknown opcodes are skipped without reading further tokens.
int main(){
    ll n, m, op, x, y, k;
    cin >> n >> m;
    for (int pos = 1; pos <= n; ++pos) a[pos] = read();
    build(1, n, 1);
    for (int step = 1; step <= m; ++step) {
        op = read();
        const bool isUpdate = (op == 3 || op == 4);
        if (isUpdate) {
            scanf("%lld%lld%lld", &x, &y, &k);
            if (op == 3) update1(1, 1, n, x, y, k);  // multiply
            else         update2(1, 1, n, x, y, k);  // add
            continue;
        }
        if (op != 1 && op != 2) continue;  // unknown opcode: ignored
        cin >> x >> y;
        printf("%lld\n", qsum(1, 1, n, x, y, op));
    }
}
京公网安备 11010502036488号