题目链接

PEEK70 区间加乘与单点求值

题目描述

给定一个长度为 的数组 。你需要支持三种操作:

  1. 区间加数:将区间 内的所有元素都加上一个值
  2. 区间乘数:将区间 内的所有元素都乘上一个值
  3. 单点求值:查询下标为 的元素的值,需要对模数 998244353 取模。

解题思路

本题是线段树的经典应用,涉及到两种不同优先级的区间修改操作:加法和乘法。

由于 (a + x) * ya * y + x 的结果不同,我们必须定义一个统一的运算顺序,并设计对应的懒惰标记下推逻辑。通常,我们规定运算的优先级为先乘后加

懒惰标记

我们需要为线段树的每个节点维护两个懒惰标记:

  • add_tag:加法标记,初始为
  • mul_tag:乘法标记,初始为

一个节点 node 所代表的区间内,每个元素的真实值 val 应该是 val = val * mul_tag + add_tag

标记下推 (push_down)

这是本题的核心。当我们将父节点 p 的标记 (p_mul, p_add) 下推给子节点 c (c_mul, c_add) 时:

  1. 更新子节点的区间和c_sum_new = (c_sum * p_mul + p_add * len) % P 其中 len 是子节点 c 所代表区间的长度。

  2. 更新子节点的标记

    • c_mul_new = (c_mul * p_mul) % P
    • c_add_new = (c_add * p_mul + p_add) % P 这个逻辑保证了“先乘后加”的优先级:子节点原有的加法标记 c_add 也要先被父节点的乘法 p_mul 影响,然后再叠加上父节点的加法 p_add

区间修改

  • 区间加 x: 对于完全覆盖的节点,add_tag 加上 x,区间和 sum 加上 x * len
  • 区间乘 x: 对于完全覆盖的节点,mul_tagadd_tag 都乘以 x,区间和 sum 也乘以 x

单点查询

从根节点递归到对应的叶子节点。在路径上,每次进入子节点前都执行 push_down,以确保将所有祖先节点的修改都应用下去。到达叶子节点时,其 sum 值即为答案。

注意:在处理取模运算时,由于中间结果可能为负数,需要使用 (x % P + P) % P 的技巧来确保结果始终为正。

代码

#include <iostream>
#include <vector>

using namespace std;

const int MAXN = 100005;
long long arr[MAXN];
long long tree[4 * MAXN];
long long mul_tag[4 * MAXN];
long long add_tag[4 * MAXN];
int n, q;
const long long p = 998244353;

void push_up(int node) {
    tree[node] = (tree[node * 2] + tree[node * 2 + 1]) % p;
}

void push_down(int node, int l, int r) {
    if (mul_tag[node] == 1 && add_tag[node] == 0) return;

    int mid = l + (r - l) / 2;
    int lc = node * 2, rc = node * 2 + 1;

    // 更新左子节点
    tree[lc] = (tree[lc] * mul_tag[node] + add_tag[node] * (mid - l + 1)) % p;
    tree[lc] = (tree[lc] + p) % p;
    mul_tag[lc] = (mul_tag[lc] * mul_tag[node]) % p;
    add_tag[lc] = (add_tag[lc] * mul_tag[node] + add_tag[node]) % p;

    // 更新右子节点
    tree[rc] = (tree[rc] * mul_tag[node] + add_tag[node] * (r - mid)) % p;
    tree[rc] = (tree[rc] + p) % p;
    mul_tag[rc] = (mul_tag[rc] * mul_tag[node]) % p;
    add_tag[rc] = (add_tag[rc] * mul_tag[node] + add_tag[node]) % p;

    // 重置当前节点标记
    mul_tag[node] = 1;
    add_tag[node] = 0;
}

void build(int node, int l, int r) {
    mul_tag[node] = 1;
    add_tag[node] = 0;
    if (l == r) {
        tree[node] = (arr[l] % p + p) % p;
        return;
    }
    int mid = l + (r - l) / 2;
    build(node * 2, l, mid);
    build(node * 2 + 1, mid + 1, r);
    push_up(node);
}

void update_add(int node, int l, int r, int ql, int qr, long long val) {
    if (ql <= l && r <= qr) {
        tree[node] = (tree[node] + val * (r - l + 1));
        tree[node] = (tree[node] % p + p) % p;
        add_tag[node] = (add_tag[node] + val);
        add_tag[node] = (add_tag[node] % p + p) % p;
        return;
    }
    push_down(node, l, r);
    int mid = l + (r - l) / 2;
    if (ql <= mid) update_add(node * 2, l, mid, ql, qr, val);
    if (qr > mid) update_add(node * 2 + 1, mid + 1, r, ql, qr, val);
    push_up(node);
}

void update_mul(int node, int l, int r, int ql, int qr, long long val) {
    if (ql <= l && r <= qr) {
        tree[node] = (tree[node] * val);
        tree[node] = (tree[node] % p + p) % p;
        mul_tag[node] = (mul_tag[node] * val);
        mul_tag[node] = (mul_tag[node] % p + p) % p;
        add_tag[node] = (add_tag[node] * val);
        add_tag[node] = (add_tag[node] % p + p) % p;
        return;
    }
    push_down(node, l, r);
    int mid = l + (r - l) / 2;
    if (ql <= mid) update_mul(node * 2, l, mid, ql, qr, val);
    if (qr > mid) update_mul(node * 2 + 1, mid + 1, r, ql, qr, val);
    push_up(node);
}

long long query(int node, int l, int r, int pos) {
    if (l == r) {
        return tree[node];
    }
    push_down(node, l, r);
    int mid = l + (r - l) / 2;
    if (pos <= mid) {
        return query(node * 2, l, mid, pos);
    } else {
        return query(node * 2 + 1, mid + 1, r, pos);
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> n >> q;
    for (int i = 1; i <= n; ++i) {
        cin >> arr[i];
    }

    build(1, 1, n);

    for (int i = 0; i < q; ++i) {
        int type;
        cin >> type;
        if (type == 1) { // 加法
            int l, r;
            long long x;
            cin >> l >> r >> x;
            update_add(1, 1, n, l, r, x);
        } else if (type == 2) { // 乘法
            int l, r;
            long long x;
            cin >> l >> r >> x;
            update_mul(1, 1, n, l, r, x);
        } else {
            int pos;
            cin >> pos;
            cout << query(1, 1, n, pos) << "\n";
        }
    }
    return 0;
}

import java.util.Scanner;

public class Main {
    static final int MAXN = 100005;
    static long[] arr = new long[MAXN];
    static long[] tree = new long[4 * MAXN];
    static long[] mulTag = new long[4 * MAXN];
    static long[] addTag = new long[4 * MAXN];
    static int n, q;
    static final long p = 998244353;

    static void pushUp(int node) {
        tree[node] = (tree[node * 2] + tree[node * 2 + 1]) % p;
    }

    static void pushDown(int node, int l, int r) {
        if (mulTag[node] == 1 && addTag[node] == 0) return;

        int mid = l + (r - l) / 2;
        int lc = node * 2, rc = node * 2 + 1;

        tree[lc] = (tree[lc] * mulTag[node] + addTag[node] * (mid - l + 1));
        tree[lc] = (tree[lc] % p + p) % p;
        mulTag[lc] = (mulTag[lc] * mulTag[node]) % p;
        addTag[lc] = (addTag[lc] * mulTag[node] + addTag[node]) % p;

        tree[rc] = (tree[rc] * mulTag[node] + addTag[node] * (r - mid));
        tree[rc] = (tree[rc] % p + p) % p;
        mulTag[rc] = (mulTag[rc] * mulTag[node]) % p;
        addTag[rc] = (addTag[rc] * mulTag[node] + addTag[node]) % p;
        
        mulTag[node] = 1;
        addTag[node] = 0;
    }

    static void build(int node, int l, int r) {
        mulTag[node] = 1;
        addTag[node] = 0;
        if (l == r) {
            tree[node] = (arr[l] % p + p) % p;
            return;
        }
        int mid = l + (r - l) / 2;
        build(node * 2, l, mid);
        build(node * 2 + 1, mid + 1, r);
        pushUp(node);
    }

    static void updateAdd(int node, int l, int r, int ql, int qr, long val) {
        if (ql <= l && r <= qr) {
            tree[node] = (tree[node] + val * (r - l + 1));
            tree[node] = (tree[node] % p + p) % p;
            addTag[node] = (addTag[node] + val);
            addTag[node] = (addTag[node] % p + p) % p;
            return;
        }
        pushDown(node, l, r);
        int mid = l + (r - l) / 2;
        if (ql <= mid) updateAdd(node * 2, l, mid, ql, qr, val);
        if (qr > mid) updateAdd(node * 2 + 1, mid + 1, r, ql, qr, val);
        pushUp(node);
    }

    static void updateMul(int node, int l, int r, int ql, int qr, long val) {
        if (ql <= l && r <= qr) {
            tree[node] = (tree[node] * val);
            tree[node] = (tree[node] % p + p) % p;
            mulTag[node] = (mulTag[node] * val);
            mulTag[node] = (mulTag[node] % p + p) % p;
            addTag[node] = (addTag[node] * val);
            addTag[node] = (addTag[node] % p + p) % p;
            return;
        }
        pushDown(node, l, r);
        int mid = l + (r - l) / 2;
        if (ql <= mid) updateMul(node * 2, l, mid, ql, qr, val);
        if (qr > mid) updateMul(node * 2 + 1, mid + 1, r, ql, qr, val);
        pushUp(node);
    }

    static long query(int node, int l, int r, int pos) {
        if (l == r) {
            return tree[node];
        }
        pushDown(node, l, r);
        int mid = l + (r - l) / 2;
        if (pos <= mid) {
            return query(node * 2, l, mid, pos);
        } else {
            return query(node * 2 + 1, mid + 1, r, pos);
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        n = sc.nextInt();
        q = sc.nextInt();
        for (int i = 1; i <= n; i++) {
            arr[i] = sc.nextLong();
        }

        build(1, 1, n);

        for (int i = 0; i < q; i++) {
            int type = sc.nextInt();
            if (type == 1) { // 加法
                int l = sc.nextInt();
                int r = sc.nextInt();
                long x = sc.nextLong();
                updateAdd(1, 1, n, l, r, x);
            } else if (type == 2) { // 乘法
                int l = sc.nextInt();
                int r = sc.nextInt();
                long x = sc.nextLong();
                updateMul(1, 1, n, l, r, x);
            } else {
                int pos = sc.nextInt();
                System.out.println(query(1, 1, n, pos));
            }
        }
    }
}

import sys

# 增加递归深度
sys.setrecursionlimit(100005)

n, q = 0, 0
p = 998244353
arr = []
tree = []
mul_tag = []
add_tag = []

def push_up(node):
    tree[node] = (tree[node * 2] + tree[node * 2 + 1]) % p

def push_down(node, l, r):
    if mul_tag[node] == 1 and add_tag[node] == 0:
        return

    mid = l + (r - l) // 2
    lc, rc = node * 2, node * 2 + 1

    tree[lc] = (tree[lc] * mul_tag[node] + add_tag[node] * (mid - l + 1)) % p
    mul_tag[lc] = (mul_tag[lc] * mul_tag[node]) % p
    add_tag[lc] = (add_tag[lc] * mul_tag[node] + add_tag[node]) % p

    tree[rc] = (tree[rc] * mul_tag[node] + add_tag[node] * (r - mid)) % p
    mul_tag[rc] = (mul_tag[rc] * mul_tag[node]) % p
    add_tag[rc] = (add_tag[rc] * mul_tag[node] + add_tag[node]) % p

    mul_tag[node] = 1
    add_tag[node] = 0

def build(node, l, r):
    mul_tag[node] = 1
    add_tag[node] = 0
    if l == r:
        tree[node] = arr[l] % p
        return
    mid = l + (r - l) // 2
    build(node * 2, l, mid)
    build(node * 2 + 1, mid + 1, r)
    push_up(node)

def update_add(node, l, r, ql, qr, val):
    if ql <= l and r <= qr:
        tree[node] = (tree[node] + val * (r - l + 1)) % p
        add_tag[node] = (add_tag[node] + val) % p
        return
    
    push_down(node, l, r)
    mid = l + (r - l) // 2
    if ql <= mid:
        update_add(node * 2, l, mid, ql, qr, val)
    if qr > mid:
        update_add(node * 2 + 1, mid + 1, r, ql, qr, val)
    push_up(node)

def update_mul(node, l, r, ql, qr, val):
    if ql <= l and r <= qr:
        tree[node] = (tree[node] * val) % p
        mul_tag[node] = (mul_tag[node] * val) % p
        add_tag[node] = (add_tag[node] * val) % p
        return

    push_down(node, l, r)
    mid = l + (r - l) // 2
    if ql <= mid:
        update_mul(node * 2, l, mid, ql, qr, val)
    if qr > mid:
        update_mul(node * 2 + 1, mid + 1, r, ql, qr, val)
    push_up(node)

def query(node, l, r, pos):
    if l == r:
        return tree[node]
    
    push_down(node, l, r)
    mid = l + (r - l) // 2
    if pos <= mid:
        return query(node * 2, l, mid, pos)
    else:
        return query(node * 2 + 1, mid + 1, r, pos)

def solve():
    global n, q, p, arr, tree, mul_tag, add_tag
    line1 = list(map(int, sys.stdin.readline().split()))
    n, q = line1[0], line1[1]

    arr = [0] + list(map(int, sys.stdin.readline().split()))
    
    tree = [0] * (4 * (n + 1))
    mul_tag = [0] * (4 * (n + 1))
    add_tag = [0] * (4 * (n + 1))

    build(1, 1, n)

    results = []
    for _ in range(q):
        line = list(map(int, sys.stdin.readline().split()))
        op_type = line[0]
        if op_type == 1: # 加法
            l, r, x = line[1], line[2], line[3]
            update_add(1, 1, n, l, r, x)
        elif op_type == 2: # 乘法
            l, r, x = line[1], line[2], line[3]
            update_mul(1, 1, n, l, r, x)
        else: # 查询
            pos = line[1]
            results.append(str(query(1, 1, n, pos)))
            
    if results:
        sys.stdout.write("\n".join(results) + "\n")

solve()

算法及复杂度

  • 算法:线段树 (Segment Tree) + 双重懒惰标记 (加法与乘法)
  • 时间复杂度。建树需要 ,后续 次操作每次需要
  • 空间复杂度,用于存储线段树的节点信息及懒惰标记。