题解:BISHI128 区间加乘与单点求值
题目链接
题目描述
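给定长度为 $n$ 的数组 $a_1,\dots,a_n$,支持 $m$ 次操作,输出均对 $998244353$ 取模:
- 操作 1 `1 l r x`:对区间 $[l,r]$ 内每个 $i$ 执行 $a_i \leftarrow a_i + x$;
- 操作 2 `2 l r x`:对区间 $[l,r]$ 内每个 $i$ 执行 $a_i \leftarrow a_i \times x$;
- 操作 3 `3 p`:输出 $a_p$。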
解题思路
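用线段树(懒标记)维护“区间仿射变换” $v \mapsto \mathrm{mul}\cdot v + \mathrm{add} \pmod{998244353}$。
- 对区间乘 $x$:把当前节点的懒标记 $(\mathrm{mul},\mathrm{add})$ 更新为 $(\mathrm{mul}\cdot x,\ \mathrm{add}\cdot x)$;
- 对区间加 $x$:把懒标记 $(\mathrm{mul},\mathrm{add})$ 更新为 $(\mathrm{mul},\ \mathrm{add}+x)$;
- 单点查询:自顶向下下推懒标记,走到叶子返回值。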
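由于只做单点查询,不必维护区间和:内部节点只存一个尚未下推的仿射变换;叶子节点把初值 $a_i$ 编码成常函数 $(\mathrm{mul},\mathrm{add})=(0,a_i)$,这样叶子的 $\mathrm{add}$ 就是该位置的当前值。下推时,子节点原有标记 $(m,a)$ 先作用、父节点标记 $(M,A)$ 后作用,复合结果为 $(Mm,\ Ma+A)$,即对子标记先乘 $M$ 再加 $A$,然后把父标记重置为恒等变换 $(1,0)$。查询时沿根到叶子的路径逐层下推,最后返回叶子的 $\mathrm{add}$。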
时间复杂度:每次操作/查询 $O(\log n)$。
代码
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
// Lazy segment tree over positions 1..n. Node `idx` stores a pending affine map
// v -> mul[idx] * v + add[idx] (mod MOD) for its whole segment. A leaf encodes its
// value as the constant map (mul, add) = (0, value), so once every ancestor has been
// pushed, the leaf's `add` is the element's current value.
struct SegTree {
    static constexpr i64 MOD = 998244353;
    int n;
    vector<i64> mul, add;  // per-node lazy tag: v -> mul * v + add (mod MOD)

    // Allocates 4n+4 nodes, all holding the identity map (1, 0).
    SegTree(int n = 0) : n(n), mul(4 * n + 4, i64(1)), add(4 * n + 4, i64(0)) {}

    // Initialises the nodes covering [l, r]: internal nodes get the identity map,
    // leaves get the constant map (0, a[l] reduced into [0, MOD)).
    void build(int idx, int l, int r, const vector<i64>& a) {
        if (l == r) {
            mul[idx] = 0;
            add[idx] = norm(a[l]);
            return;
        }
        mul[idx] = 1;
        add[idx] = 0;
        const int mid = (l + r) / 2;
        build(2 * idx, l, mid, a);
        build(2 * idx + 1, mid + 1, r, a);
    }

    // Reduces x into [0, MOD).
    static i64 norm(i64 x) {
        x %= MOD;
        return x < 0 ? x + MOD : x;
    }

    // Post-composes multiplication by x: (mul, add) -> (mul*x, add*x).
    void apply_mul(int idx, i64 x) {
        const i64 factor = norm(x);
        mul[idx] = mul[idx] * factor % MOD;
        add[idx] = add[idx] * factor % MOD;
    }

    // Post-composes addition of x: (mul, add) -> (mul, add + x).
    void apply_add(int idx, i64 x) {
        add[idx] = (add[idx] + norm(x)) % MOD;
    }

    // Pushes this node's pending map onto both children and resets it to identity.
    // Child map (m, a) followed by parent map (M, A) is (M*m, M*a + A): multiply, then add.
    void push(int idx) {
        const int kids[2] = {2 * idx, 2 * idx + 1};
        for (int child : kids) {
            apply_mul(child, mul[idx]);
            apply_add(child, add[idx]);
        }
        mul[idx] = 1;
        add[idx] = 0;
    }

    // Multiplies every element of [ql, qr] by x; node idx covers [l, r].
    void range_mul(int idx, int l, int r, int ql, int qr, i64 x) {
        range_apply(idx, l, r, ql, qr, x, true);
    }

    // Adds x to every element of [ql, qr]; node idx covers [l, r].
    void range_add(int idx, int l, int r, int ql, int qr, i64 x) {
        range_apply(idx, l, r, ql, qr, x, false);
    }

    // Returns element p (in [0, MOD)), pushing pending maps along the root-to-leaf path.
    i64 point_query(int idx, int l, int r, int p) {
        while (l != r) {
            push(idx);
            const int mid = (l + r) / 2;
            if (p <= mid) {
                idx = 2 * idx;
                r = mid;
            } else {
                idx = 2 * idx + 1;
                l = mid + 1;
            }
        }
        return norm(add[idx]);
    }

private:
    // Shared descent for range updates: tags fully covered nodes, pushes partial ones.
    void range_apply(int idx, int l, int r, int ql, int qr, i64 x, bool is_mul) {
        if (ql <= l && r <= qr) {
            if (is_mul) {
                apply_mul(idx, x);
            } else {
                apply_add(idx, x);
            }
            return;
        }
        push(idx);
        const int mid = (l + r) / 2;
        if (ql <= mid) range_apply(2 * idx, l, mid, ql, qr, x, is_mul);
        if (qr > mid) range_apply(2 * idx + 1, mid + 1, r, ql, qr, x, is_mul);
    }
};
// Reads n, m, the initial array a[1..n], then m operations:
//   "1 l r x" -> add x on [l, r];  "2 l r x" -> multiply [l, r] by x;
//   any other opcode reads p and prints a[p] mod 998244353 on its own line.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    if (!(cin >> n >> m)) return 0;  // no input: nothing to do

    vector<i64> a(n + 1);  // 1-based; a[0] unused
    for (int i = 1; i <= n; ++i) cin >> a[i];

    SegTree st(n);
    st.build(1, 1, n, a);

    while (m--) {
        int op;
        cin >> op;
        if (op == 1 || op == 2) {
            int l, r;
            i64 x;
            cin >> l >> r >> x;
            if (op == 1) {
                st.range_add(1, 1, n, l, r, x);
            } else {
                st.range_mul(1, 1, n, l, r, x);
            }
        } else {
            int p;
            cin >> p;
            cout << st.point_query(1, 1, n, p) << '\n';
        }
    }
    return 0;
}
import java.io.*;
/**
 * Range add / range multiply / point query, all modulo 998244353, via a lazy segment tree.
 *
 * <p>Input: {@code n m}, then {@code a_1 .. a_n}, then {@code m} operations: {@code 1 l r x}
 * adds {@code x} on {@code [l, r]}, {@code 2 l r x} multiplies {@code [l, r]} by {@code x}, and
 * any other opcode reads {@code p} and prints {@code a_p} on its own line.
 */
public class Main {
    /** Buffered reader of whitespace-separated decimal integers from a byte stream. */
    static class FastScanner {
        private final InputStream in;
        private final byte[] buf = new byte[1 << 16];
        private int p = 0, l = 0;

        FastScanner(InputStream is) {
            in = is;
        }

        /** Returns the next byte as an unsigned value 0..255, or -1 at end of stream. */
        private int read() throws IOException {
            if (p >= l) {
                l = in.read(buf);
                p = 0;
                if (l <= 0) {
                    return -1;
                }
            }
            // Mask so a 0xFF data byte cannot be confused with the -1 end-of-stream marker.
            return buf[p++] & 0xFF;
        }

        /**
         * Skips blank/control bytes and returns the first byte of the next token.
         *
         * @throws EOFException if the stream ends before a token starts; the explicit
         *     {@code c >= 0} check matters because -1 is also {@code <= ' '} and an unguarded
         *     skip loop would spin forever at end of input
         */
        private int skipBlanks() throws IOException {
            int c;
            do {
                c = read();
            } while (c >= 0 && c <= ' ');
            if (c < 0) {
                throw new EOFException("unexpected end of input");
            }
            return c;
        }

        /** Reads the next optionally-negative decimal token as an int. */
        int nextInt() throws IOException {
            return (int) nextLong();
        }

        /** Reads the next optionally-negative decimal token as a long. */
        long nextLong() throws IOException {
            int c = skipBlanks();
            long sign = 1;
            long x = 0;
            if (c == '-') {
                sign = -1;
                c = read();
            }
            while (c > ' ') {
                x = x * 10 + (c - '0');
                c = read();
            }
            return x * sign;
        }
    }

    /**
     * Lazy segment tree over positions 1..n. Node {@code idx} stores a pending affine map
     * {@code v -> mul[idx] * v + add[idx] (mod MOD)} for its whole segment; a leaf encodes its
     * value as the constant map {@code (mul, add) = (0, value)}. All tags stay in {@code [0, MOD)}.
     */
    static class SegTree {
        static final long MOD = 998244353L;
        int n;
        long[] mul, add;

        SegTree(int n) {
            this.n = n;
            mul = new long[4 * n + 4];
            add = new long[4 * n + 4]; // zero-initialised
            for (int i = 0; i < mul.length; i++) {
                mul[i] = 1; // identity map (1, 0)
            }
        }

        /** Initialises the nodes covering {@code [l, r]} from {@code a[l..r]}. */
        void build(int idx, int l, int r, long[] a) {
            mul[idx] = 1;
            add[idx] = 0;
            if (l == r) {
                mul[idx] = 0;
                add[idx] = norm(a[l]);
                return;
            }
            int m = (l + r) >>> 1;
            build(idx << 1, l, m, a);
            build(idx << 1 | 1, m + 1, r, a);
        }

        /** Reduces {@code x} into {@code [0, MOD)}. */
        static long norm(long x) {
            x %= MOD;
            if (x < 0) {
                x += MOD;
            }
            return x;
        }

        /** Post-composes multiplication by {@code x}: (mul, add) -> (mul*x, add*x). */
        void applyMul(int idx, long x) {
            x = norm(x);
            mul[idx] = (mul[idx] * x) % MOD;
            add[idx] = (add[idx] * x) % MOD;
        }

        /** Post-composes addition of {@code x}: (mul, add) -> (mul, add + x). */
        void applyAdd(int idx, long x) {
            x = norm(x);
            add[idx] = (add[idx] + x) % MOD;
        }

        /**
         * Pushes node {@code idx}'s pending map onto both children and resets it to identity.
         * Child map (m, a) followed by parent map (M, A) is (M*m, M*a + A): multiply, then add.
         */
        void push(int idx) {
            applyMul(idx << 1, mul[idx]);
            applyAdd(idx << 1, add[idx]);
            applyMul(idx << 1 | 1, mul[idx]);
            applyAdd(idx << 1 | 1, add[idx]);
            mul[idx] = 1;
            add[idx] = 0;
        }

        /** Multiplies every element of {@code [ql, qr]} by {@code x}; node idx covers [l, r]. */
        void rangeMul(int idx, int l, int r, int ql, int qr, long x) {
            if (ql <= l && r <= qr) {
                applyMul(idx, x);
                return;
            }
            push(idx);
            int m = (l + r) >>> 1;
            if (ql <= m) {
                rangeMul(idx << 1, l, m, ql, qr, x);
            }
            if (qr > m) {
                rangeMul(idx << 1 | 1, m + 1, r, ql, qr, x);
            }
        }

        /** Adds {@code x} to every element of {@code [ql, qr]}; node idx covers [l, r]. */
        void rangeAdd(int idx, int l, int r, int ql, int qr, long x) {
            if (ql <= l && r <= qr) {
                applyAdd(idx, x);
                return;
            }
            push(idx);
            int m = (l + r) >>> 1;
            if (ql <= m) {
                rangeAdd(idx << 1, l, m, ql, qr, x);
            }
            if (qr > m) {
                rangeAdd(idx << 1 | 1, m + 1, r, ql, qr, x);
            }
        }

        /** Returns element {@code p} after pushing every pending map on the root-to-leaf path. */
        long pointQuery(int idx, int l, int r, int p) {
            if (l == r) {
                return add[idx]; // already in [0, MOD): build/applyMul/applyAdd keep it reduced
            }
            push(idx);
            int m = (l + r) >>> 1;
            return (p <= m) ? pointQuery(idx << 1, l, m, p) : pointQuery(idx << 1 | 1, m + 1, r, p);
        }
    }

    /** Reads the problem from stdin and prints one line per point query. */
    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int n;
        int m;
        try {
            n = fs.nextInt();
            m = fs.nextInt();
        } catch (EOFException ignored) {
            return; // empty input: nothing to answer
        }
        long[] a = new long[n + 1]; // 1-based; a[0] unused
        for (int i = 1; i <= n; i++) {
            a[i] = fs.nextLong();
        }
        SegTree st = new SegTree(n);
        st.build(1, 1, n, a);
        StringBuilder out = new StringBuilder();
        while (m-- > 0) {
            int op = fs.nextInt();
            if (op == 1 || op == 2) {
                int l = fs.nextInt();
                int r = fs.nextInt();
                long x = fs.nextLong();
                if (op == 1) {
                    st.rangeAdd(1, 1, n, l, r, x);
                } else {
                    st.rangeMul(1, 1, n, l, r, x);
                }
            } else {
                int p = fs.nextInt();
                out.append(st.pointQuery(1, 1, n, p)).append('\n');
            }
        }
        System.out.print(out);
    }
}
import sys
# Read every whitespace-separated token at once. On empty input there is nothing
# to answer, so exit cleanly instead of crashing with StopIteration.
data = sys.stdin.buffer.read().split()
if not data:
    sys.exit(0)
it = iter(data)
n = int(next(it)); m = int(next(it))
# a[1..n] holds the initial values; index 0 is unused so indices match the 1-based tree.
a = [0] * (n + 1)
for i in range(1, n + 1):
    a[i] = int(next(it))
# Lazy tags: node idx applies v -> mul[idx] * v + add[idx] (mod MOD) to its whole segment.
size = 4 * n + 5
mul = [1] * size
add = [0] * size
MOD = 998244353
def build(idx, l, r):
    """Initialise the tag tree for the segment [l, r] from the global list ``a``.

    Internal nodes get the identity map (mul=1, add=0). A leaf encodes its value
    as the constant map (mul=0, add=a[l] mod MOD), so its ``add`` is the current
    element value once every ancestor tag has been pushed down.
    """
    mul[idx] = 1; add[idx] = 0
    if l == r:
        mul[idx] = 0
        # Reduce here so every stored tag lies in [0, MOD), the same invariant that
        # apply_mul/apply_add maintain.
        add[idx] = a[l] % MOD
        return
    mid = (l + r) >> 1
    build(idx<<1, l, mid)
    build(idx<<1|1, mid+1, r)
def apply_mul(idx, x):
    """Post-compose multiplication by ``x`` onto node ``idx``'s tag.

    The affine tag (mul, add) becomes (mul * x, add * x), reduced modulo MOD.
    """
    factor = x % MOD
    mul[idx] = mul[idx] * factor % MOD
    add[idx] = add[idx] * factor % MOD
def apply_add(idx, x):
    """Post-compose addition of ``x`` onto node ``idx``'s tag: (mul, add) -> (mul, add + x) mod MOD."""
    shifted = add[idx] + x
    add[idx] = shifted % MOD
def push(idx):
    """Compose node ``idx``'s pending tag onto both children, then reset it to identity.

    A child tag (m, a) followed by the parent tag (M, A) equals (M*m, M*a + A),
    which is exactly apply_mul(M) followed by apply_add(A) on the child.
    """
    parent_mul, parent_add = mul[idx], add[idx]
    for child in (idx << 1, (idx << 1) | 1):
        apply_mul(child, parent_mul)
        apply_add(child, parent_add)
    mul[idx], add[idx] = 1, 0
def range_mul(idx, l, r, ql, qr, x):
    """Multiply every element in [ql, qr] by ``x``; node ``idx`` covers [l, r]."""
    if l >= ql and qr >= r:
        apply_mul(idx, x)
        return
    push(idx)
    mid = (l + r) >> 1
    left, right = idx << 1, (idx << 1) | 1
    if ql <= mid:
        range_mul(left, l, mid, ql, qr, x)
    if mid < qr:
        range_mul(right, mid + 1, r, ql, qr, x)
def range_add(idx, l, r, ql, qr, x):
    """Add ``x`` to every element in [ql, qr]; node ``idx`` covers [l, r]."""
    if l >= ql and qr >= r:
        apply_add(idx, x)
        return
    push(idx)
    mid = (l + r) >> 1
    left, right = idx << 1, (idx << 1) | 1
    if ql <= mid:
        range_add(left, l, mid, ql, qr, x)
    if mid < qr:
        range_add(right, mid + 1, r, ql, qr, x)
def point_query(idx, l, r, p):
    """Return element ``p`` modulo MOD, pushing pending tags along the root-to-leaf path.

    The leaf holds the constant map (0, value), so its ``add`` is the answer.
    """
    while l != r:
        push(idx)
        mid = (l + r) >> 1
        if p <= mid:
            idx, r = idx << 1, mid
        else:
            idx, l = (idx << 1) | 1, mid + 1
    return add[idx] % MOD
# Build the tree, then process the m operations:
#   "1 l r x" -> add x on [l, r];  "2 l r x" -> multiply [l, r] by x;
#   any other opcode reads p and records a[p] mod MOD.
build(1, 1, n)
out_lines = []
for _ in range(m):
    op = int(next(it))
    if op == 1:
        l = int(next(it)); r = int(next(it)); x = int(next(it)) % MOD
        range_add(1, 1, n, l, r, x)
    elif op == 2:
        l = int(next(it)); r = int(next(it)); x = int(next(it)) % MOD
        range_mul(1, 1, n, l, r, x)
    else:
        p = int(next(it))
        out_lines.append(str(point_query(1, 1, n, p)))
# Terminate every answer line, including the last, so each answer is newline-terminated.
if out_lines:
    sys.stdout.write('\n'.join(out_lines) + '\n')
算法及复杂度
- 算法:线段树懒标记维护仿射变换
- 时间复杂度:$O(\log n)$(每次操作/查询),建树 $O(n)$
- 空间复杂度:$O(n)$(约 $4n$ 个节点的懒标记数组)