题解:BISHI128 区间加乘与单点求值

题目链接

区间加乘与单点求值

题目描述

给定长度为 $n$ 的数组 $a_1, \dots, a_n$,支持 $m$ 次操作,所有结果均对 $998244353$ 取模:

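  • 操作 1(输入 `1 l r x`):对区间 $[l, r]$ 内每个 $i$ 执行 $a_i \leftarrow a_i + x$
  • 操作 2(输入 `2 l r x`):对区间 $[l, r]$ 内每个 $i$ 执行 $a_i \leftarrow a_i \times x$
  • 操作 3(输入 `3 p`):输出 $a_p \bmod 998244353$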

解题思路

线段树(懒标记)维护“区间仿射变换”

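  • 对区间乘 $x$:把当前节点的懒标记 $(\text{mul}, \text{add})$ 更新为 $(\text{mul}\cdot x,\ \text{add}\cdot x)$
  • 对区间加 $x$:把懒标记 $(\text{mul}, \text{add})$ 更新为 $(\text{mul},\ \text{add} + x)$
  • 单点查询:自顶向下逐层下推懒标记,走到叶子后返回其值。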

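由于只做单点查询,不必维护区间和。每个节点只存一个懒标记 $(\text{mul}, \text{add})$,表示尚未下传、需要作用到整段上的变换 $v \mapsto \text{mul}\cdot v + \text{add}$。叶子初始化为 $(0, a_i)$,即"常值变换",因此叶子的 $\text{add}$ 就是该元素的当前值。下推时,把父标记 $(M, A)$ 复合到子标记 $(m, c)$ 之后,得到 $(M m,\ M c + A)$(先乘后加),再把父标记重置为恒等变换 $(1, 0)$。查询时沿根到叶的路径逐层下推,最后返回叶子的 $\text{add}$ 即可。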

时间复杂度:每次操作/查询 $O(\log n)$。

代码

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

// Segment tree over positions 1..n. Every node carries a lazy affine tag
// v -> mul * v + add (mod MOD) that still has to be applied to its subtree.
// Leaves are built with mul = 0 and add = a[l], so a leaf's tag maps any input
// to the element's current value; point queries read that add field.
struct SegTree {
    static constexpr i64 MOD = 998244353;
    int n;
    vector<i64> mul, add; // per-node tag: v -> mul * v + add (mod MOD)

    // Allocates 4n+4 tags, each set to the identity transform (mul = 1, add = 0).
    SegTree(int n = 0) : n(n), mul(4 * n + 4, i64{1}), add(4 * n + 4, i64{0}) {}

    // Resets the tags of node idx covering [l, r] from a[l..r].
    void build(int idx, int l, int r, const vector<i64>& a) {
        if (l == r) {
            mul[idx] = 0;          // constant transform ...
            add[idx] = norm(a[l]); // ... whose value is a[l] reduced into [0, MOD)
            return;
        }
        mul[idx] = 1;
        add[idx] = 0;
        const int mid = (l + r) >> 1;
        build(idx << 1, l, mid, a);
        build(idx << 1 | 1, mid + 1, r, a);
    }

    // Reduces x into [0, MOD).
    static inline i64 norm(i64 x) {
        x %= MOD;
        return x < 0 ? x + MOD : x;
    }

    // Composes "multiply by x" after node idx's pending transform.
    void apply_mul(int idx, i64 x) {
        const i64 f = norm(x);
        mul[idx] = mul[idx] * f % MOD;
        add[idx] = add[idx] * f % MOD;
    }

    // Composes "add x" after node idx's pending transform.
    void apply_add(int idx, i64 x) {
        add[idx] = (add[idx] + norm(x)) % MOD;
    }

    // Hands node idx's tag to both children (multiply first, then add),
    // then resets node idx to the identity transform.
    void push(int idx) {
        const i64 pm = mul[idx], pa = add[idx];
        for (const int child : {idx << 1, idx << 1 | 1}) {
            apply_mul(child, pm);
            apply_add(child, pa);
        }
        mul[idx] = 1;
        add[idx] = 0;
    }

    // Multiplies every element in [ql, qr] by x.
    void range_mul(int idx, int l, int r, int ql, int qr, i64 x) {
        range_apply(idx, l, r, ql, qr, [&](int node) { apply_mul(node, x); });
    }

    // Adds x to every element in [ql, qr].
    void range_add(int idx, int l, int r, int ql, int qr, i64 x) {
        range_apply(idx, l, r, ql, qr, [&](int node) { apply_add(node, x); });
    }

    // Returns the current value of position p, pushing tags along the root-to-leaf path.
    i64 point_query(int idx, int l, int r, int p) {
        while (l != r) {
            push(idx);
            const int mid = (l + r) >> 1;
            if (p <= mid) {
                idx = idx << 1;
                r = mid;
            } else {
                idx = idx << 1 | 1;
                l = mid + 1;
            }
        }
        return norm(add[idx]);
    }

private:
    // Applies `tag` to every maximal node fully inside [ql, qr], pushing pending
    // tags out of each partially covered node before descending into it.
    template <class Tag>
    void range_apply(int idx, int l, int r, int ql, int qr, const Tag& tag) {
        if (ql <= l && r <= qr) {
            tag(idx);
            return;
        }
        push(idx);
        const int mid = (l + r) >> 1;
        if (ql <= mid) range_apply(idx << 1, l, mid, ql, qr, tag);
        if (qr > mid) range_apply(idx << 1 | 1, mid + 1, r, ql, qr, tag);
    }
};

// Reads "n m", the array a[1..n], then m operations:
//   1 l r x  -> add x on [l, r]
//   2 l r x  -> multiply [l, r] by x
//   other p  -> print a[p] (mod 998244353)
// Stops early (without printing anything further) if the input ends prematurely.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    if (!(cin >> n >> m)) return 0;
    vector<i64> a(n + 1);
    for (int i = 1; i <= n; i++) cin >> a[i];
    SegTree st(n);
    st.build(1, 1, n, a);
    while (m--) {
        int op;
        // A failed extraction stores 0, which would fall into the query branch
        // and emit a bogus answer; stop instead.
        if (!(cin >> op)) break;
        if (op == 1 || op == 2) {
            int l, r;
            i64 x;
            if (!(cin >> l >> r >> x)) break;
            if (op == 1) st.range_add(1, 1, n, l, r, x);
            else         st.range_mul(1, 1, n, l, r, x);
        } else {
            int p;
            if (!(cin >> p)) break;
            cout << st.point_query(1, 1, n, p) << '\n';
        }
    }
    return 0;
}
import java.io.*;

/**
 * Interval add / interval multiply / point query, all modulo 998244353.
 *
 * <p>Input: {@code n m}, then {@code a[1..n]}, then {@code m} operations: {@code 1 l r x} adds x
 * on [l, r], {@code 2 l r x} multiplies [l, r] by x, and any other opcode reads {@code p} and
 * prints a[p].
 */
public class Main {
    /**
     * Buffered reader of whitespace-separated integers from a byte stream.
     *
     * <p>ASCII control/space bytes ({@code <= ' '}) and non-ASCII bytes ({@code >= 0x80}) act as
     * separators. Requesting a token after the input is exhausted throws {@link EOFException}.
     */
    static class FastScanner {
        private final InputStream in;
        private final byte[] buf = new byte[1 << 16];
        private int p = 0, l = 0;

        FastScanner(InputStream is) {
            in = is;
        }

        /** Returns the next byte as an unsigned value 0..255, or -1 at end of input. */
        private int read() throws IOException {
            if (p >= l) {
                l = in.read(buf);
                p = 0;
                if (l <= 0) {
                    return -1;
                }
            }
            return buf[p++] & 0xFF;
        }

        /** True for token separators; the EOF marker -1 also counts as one. */
        private static boolean isSeparator(int c) {
            return c <= ' ' || c >= 0x80;
        }

        /** Skips separators and returns the first byte of the next token. */
        private int skipSeparators() throws IOException {
            int c;
            do {
                c = read();
                if (c == -1) {
                    // -1 would otherwise be treated as a separator and this loop would spin forever.
                    throw new EOFException("unexpected end of input");
                }
            } while (isSeparator(c));
            return c;
        }

        /** Parses the next token as an optionally negative decimal long. */
        long nextLong() throws IOException {
            int c = skipSeparators();
            long sign = 1;
            if (c == '-') {
                sign = -1;
                c = read();
            }
            long x = 0;
            while (!isSeparator(c)) {
                x = x * 10 + (c - '0');
                c = read();
            }
            return x * sign;
        }

        /** Parses the next token as an int (low 32 bits of {@link #nextLong()}). */
        int nextInt() throws IOException {
            return (int) nextLong();
        }
    }

    /**
     * Segment tree over positions 1..n whose nodes hold lazy affine tags
     * {@code v -> mul * v + add (mod MOD)}. Leaves are built with {@code mul = 0} and
     * {@code add = a[l]}, so a leaf's {@code add} is that element's current value.
     */
    static class SegTree {
        static final long MOD = 998244353L;
        int n;
        long[] mul, add; // lazy tag per node: v -> mul * v + add (mod MOD)

        SegTree(int n) {
            this.n = n;
            mul = new long[4 * n + 4];
            add = new long[4 * n + 4]; // zero-initialised: additive part of the identity tag
            java.util.Arrays.fill(mul, 1L); // multiplicative part of the identity tag
        }

        /** Initialises node idx covering [l, r] from a[l..r]. */
        void build(int idx, int l, int r, long[] a) {
            if (l == r) {
                mul[idx] = 0; // constant transform whose value is a[l]
                add[idx] = norm(a[l]);
                return;
            }
            mul[idx] = 1;
            add[idx] = 0;
            int m = (l + r) >>> 1;
            build(idx << 1, l, m, a);
            build(idx << 1 | 1, m + 1, r, a);
        }

        /** Reduces x into [0, MOD). */
        static long norm(long x) {
            x %= MOD;
            if (x < 0) {
                x += MOD;
            }
            return x;
        }

        /** Composes "multiply by x" after node idx's pending transform. */
        void applyMul(int idx, long x) {
            x = norm(x);
            mul[idx] = (mul[idx] * x) % MOD;
            add[idx] = (add[idx] * x) % MOD;
        }

        /** Composes "add x" after node idx's pending transform. */
        void applyAdd(int idx, long x) {
            x = norm(x);
            add[idx] = (add[idx] + x) % MOD;
        }

        /** Hands node idx's tag to both children (multiply, then add) and resets it to identity. */
        void push(int idx) {
            long pm = mul[idx];
            long pa = add[idx];
            applyMul(idx << 1, pm);
            applyAdd(idx << 1, pa);
            applyMul(idx << 1 | 1, pm);
            applyAdd(idx << 1 | 1, pa);
            mul[idx] = 1;
            add[idx] = 0;
        }

        /** Multiplies every element in [ql, qr] by x. */
        void rangeMul(int idx, int l, int r, int ql, int qr, long x) {
            if (ql <= l && r <= qr) {
                applyMul(idx, x);
                return;
            }
            push(idx);
            int m = (l + r) >>> 1;
            if (ql <= m) {
                rangeMul(idx << 1, l, m, ql, qr, x);
            }
            if (qr > m) {
                rangeMul(idx << 1 | 1, m + 1, r, ql, qr, x);
            }
        }

        /** Adds x to every element in [ql, qr]. */
        void rangeAdd(int idx, int l, int r, int ql, int qr, long x) {
            if (ql <= l && r <= qr) {
                applyAdd(idx, x);
                return;
            }
            push(idx);
            int m = (l + r) >>> 1;
            if (ql <= m) {
                rangeAdd(idx << 1, l, m, ql, qr, x);
            }
            if (qr > m) {
                rangeAdd(idx << 1 | 1, m + 1, r, ql, qr, x);
            }
        }

        /** Returns the current value of position p, pushing tags along the path to its leaf. */
        long pointQuery(int idx, int l, int r, int p) {
            if (l == r) {
                return norm(add[idx]);
            }
            push(idx);
            int m = (l + r) >>> 1;
            return (p <= m) ? pointQuery(idx << 1, l, m, p) : pointQuery(idx << 1 | 1, m + 1, r, p);
        }
    }

    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int n;
        int m;
        try {
            n = fs.nextInt();
            m = fs.nextInt();
        } catch (EOFException e) {
            return; // no header: nothing to do
        }
        long[] a = new long[n + 1];
        for (int i = 1; i <= n; i++) {
            a[i] = fs.nextLong();
        }
        SegTree st = new SegTree(n);
        st.build(1, 1, n, a);
        StringBuilder out = new StringBuilder();
        try {
            while (m-- > 0) {
                int op = fs.nextInt();
                if (op == 1) {
                    int l = fs.nextInt(), r = fs.nextInt();
                    long x = fs.nextLong();
                    st.rangeAdd(1, 1, n, l, r, x);
                } else if (op == 2) {
                    int l = fs.nextInt(), r = fs.nextInt();
                    long x = fs.nextLong();
                    st.rangeMul(1, 1, n, l, r, x);
                } else {
                    int p = fs.nextInt();
                    out.append(st.pointQuery(1, 1, n, p)).append('\n');
                }
            }
        } catch (EOFException e) {
            // Truncated operation list: stop processing but still emit the answers gathered so far.
        }
        System.out.print(out);
    }
}
import sys

# Slurp the whole input once; tokens are then consumed lazily through `it`.
data = sys.stdin.buffer.read().split()
it = iter(data)
n = int(next(it))
m = int(next(it))
# 1-indexed array; a[0] is an unused placeholder.
a = [0] + [int(next(it)) for _ in range(n)]

# Lazy tag of every tree node: apply v -> mul * v + add (mod MOD) to its whole segment.
size = 4 * n + 5
mul = [1] * size
add = [0] * size
MOD = 998244353

def build(idx, l, r):
    """Initialise the tags of node ``idx`` covering ``[l, r]`` from the global ``a``.

    A leaf gets the constant transform ``v -> a[l]`` (mul = 0, add = a[l]);
    every internal node gets the identity transform (mul = 1, add = 0).
    """
    if l == r:
        mul[idx], add[idx] = 0, a[l]
        return
    mul[idx], add[idx] = 1, 0
    mid = (l + r) >> 1
    build(2 * idx, l, mid)
    build(2 * idx + 1, mid + 1, r)

def apply_mul(idx, x):
    """Compose "multiply by x" after the pending transform of node ``idx`` (mod MOD)."""
    factor = x % MOD
    mul[idx], add[idx] = mul[idx] * factor % MOD, add[idx] * factor % MOD

def apply_add(idx, x):
    """Compose "add x" after the pending transform of node ``idx`` (mod MOD)."""
    add[idx] += x
    add[idx] %= MOD

def push(idx):
    """Hand node ``idx``'s tag to both children (multiply first, then add), then reset it."""
    parent_mul, parent_add = mul[idx], add[idx]
    for child in (2 * idx, 2 * idx + 1):
        apply_mul(child, parent_mul)
        apply_add(child, parent_add)
    mul[idx], add[idx] = 1, 0

def range_mul(idx, l, r, ql, qr, x):
    """Multiply every element in ``[ql, qr]`` by ``x``.

    Walks the tree depth-first with an explicit stack (left child before right,
    as a recursive walk would), tagging fully covered nodes and pushing pending
    tags out of partially covered ones before descending.
    """
    pending = [(idx, l, r)]
    while pending:
        node, lo, hi = pending.pop()
        if ql <= lo and hi <= qr:
            apply_mul(node, x)
            continue
        push(node)
        mid = (lo + hi) >> 1
        # Push the right child first so the left subtree is handled first.
        if qr > mid:
            pending.append((2 * node + 1, mid + 1, hi))
        if ql <= mid:
            pending.append((2 * node, lo, mid))

def range_add(idx, l, r, ql, qr, x):
    """Add ``x`` to every element in ``[ql, qr]``.

    Walks the tree depth-first with an explicit stack (left child before right,
    as a recursive walk would), tagging fully covered nodes and pushing pending
    tags out of partially covered ones before descending.
    """
    pending = [(idx, l, r)]
    while pending:
        node, lo, hi = pending.pop()
        if ql <= lo and hi <= qr:
            apply_add(node, x)
            continue
        push(node)
        mid = (lo + hi) >> 1
        # Push the right child first so the left subtree is handled first.
        if qr > mid:
            pending.append((2 * node + 1, mid + 1, hi))
        if ql <= mid:
            pending.append((2 * node, lo, mid))

def point_query(idx, l, r, p):
    """Return the current value of position ``p`` (mod MOD).

    Descends from node ``idx`` covering ``[l, r]`` to the leaf for ``p``,
    pushing tags on the way; the leaf's ``add`` field is the element value.
    """
    while l != r:
        push(idx)
        mid = (l + r) >> 1
        if p <= mid:
            idx, r = 2 * idx, mid
        else:
            idx, l = 2 * idx + 1, mid + 1
    return add[idx] % MOD

build(1, 1, n)

# Opcode 1 -> range add, 2 -> range multiply, anything else -> point query.
answers = []
for _ in range(m):
    op = int(next(it))
    if op in (1, 2):
        l = int(next(it))
        r = int(next(it))
        x = int(next(it)) % MOD
        (range_add if op == 1 else range_mul)(1, 1, n, l, r, x)
    else:
        answers.append(str(point_query(1, 1, n, int(next(it)))))

sys.stdout.write('\n'.join(answers))

算法及复杂度

  • 算法:线段树懒标记维护仿射变换
  • 时间复杂度:每次操作/查询 $O(\log n)$,建树 $O(n)$,总计 $O(n + m\log n)$
  • 空间复杂度:$O(n)$(两个长度约为 $4n$ 的懒标记数组)