题目链接
题目描述
给定一个长度为 的数组
。你需要支持三种操作:
- 区间加数:将区间
内的所有元素都加上一个值
。
- 区间乘数:将区间
内的所有元素都乘上一个值
。
- 单点求值:查询下标为
的元素的值,需要对模数
998244353
取模。
解题思路
本题是线段树的经典应用,涉及到两种不同优先级的区间修改操作:加法和乘法。
由于 (a + x) * y
和 a * y + x
的结果不同,我们必须定义一个统一的运算顺序,并设计对应的懒惰标记下推逻辑。通常,我们规定运算的优先级为先乘后加。
懒惰标记
我们需要为线段树的每个节点维护两个懒惰标记:
add_tag
:加法标记,初始为。
mul_tag
:乘法标记,初始为。
一个节点 node
所代表的区间内,每个元素的真实值 val
应该是 val = val * mul_tag + add_tag
。
标记下推 (push_down)
这是本题的核心。当我们将父节点 p
的标记 (p_mul
, p_add
) 下推给子节点 c
(c_mul
, c_add
) 时:
-
更新子节点的区间和:
c_sum_new = (c_sum * p_mul + p_add * len) % P
其中len
是子节点c
所代表区间的长度。 -
更新子节点的标记:
c_mul_new = (c_mul * p_mul) % P
c_add_new = (c_add * p_mul + p_add) % P
这个逻辑保证了“先乘后加”的优先级:子节点原有的加法标记c_add
也要先被父节点的乘法p_mul
影响,然后再叠加上父节点的加法p_add
。
区间修改
- 区间加
x
: 对于完全覆盖的节点,add_tag
加上x
,区间和sum
加上x * len
。 - 区间乘
x
: 对于完全覆盖的节点,mul_tag
和add_tag
都乘以x
,区间和sum
也乘以x
。
单点查询
从根节点递归到对应的叶子节点。在路径上,每次进入子节点前都执行 push_down
,以确保将所有祖先节点的修改都应用下去。到达叶子节点时,其 sum
值即为答案。
注意:在处理取模运算时,由于中间结果可能为负数,需要使用 (x % P + P) % P
的技巧来确保结果始终为正。
代码
#include <iostream>
#include <vector>
using namespace std;
const int MAXN = 100005;
long long arr[MAXN];
long long tree[4 * MAXN];
long long mul_tag[4 * MAXN];
long long add_tag[4 * MAXN];
int n, q;
const long long p = 998244353;
void push_up(int node) {
tree[node] = (tree[node * 2] + tree[node * 2 + 1]) % p;
}
void push_down(int node, int l, int r) {
if (mul_tag[node] == 1 && add_tag[node] == 0) return;
int mid = l + (r - l) / 2;
int lc = node * 2, rc = node * 2 + 1;
// 更新左子节点
tree[lc] = (tree[lc] * mul_tag[node] + add_tag[node] * (mid - l + 1)) % p;
tree[lc] = (tree[lc] + p) % p;
mul_tag[lc] = (mul_tag[lc] * mul_tag[node]) % p;
add_tag[lc] = (add_tag[lc] * mul_tag[node] + add_tag[node]) % p;
// 更新右子节点
tree[rc] = (tree[rc] * mul_tag[node] + add_tag[node] * (r - mid)) % p;
tree[rc] = (tree[rc] + p) % p;
mul_tag[rc] = (mul_tag[rc] * mul_tag[node]) % p;
add_tag[rc] = (add_tag[rc] * mul_tag[node] + add_tag[node]) % p;
// 重置当前节点标记
mul_tag[node] = 1;
add_tag[node] = 0;
}
void build(int node, int l, int r) {
mul_tag[node] = 1;
add_tag[node] = 0;
if (l == r) {
tree[node] = (arr[l] % p + p) % p;
return;
}
int mid = l + (r - l) / 2;
build(node * 2, l, mid);
build(node * 2 + 1, mid + 1, r);
push_up(node);
}
void update_add(int node, int l, int r, int ql, int qr, long long val) {
if (ql <= l && r <= qr) {
tree[node] = (tree[node] + val * (r - l + 1));
tree[node] = (tree[node] % p + p) % p;
add_tag[node] = (add_tag[node] + val);
add_tag[node] = (add_tag[node] % p + p) % p;
return;
}
push_down(node, l, r);
int mid = l + (r - l) / 2;
if (ql <= mid) update_add(node * 2, l, mid, ql, qr, val);
if (qr > mid) update_add(node * 2 + 1, mid + 1, r, ql, qr, val);
push_up(node);
}
void update_mul(int node, int l, int r, int ql, int qr, long long val) {
if (ql <= l && r <= qr) {
tree[node] = (tree[node] * val);
tree[node] = (tree[node] % p + p) % p;
mul_tag[node] = (mul_tag[node] * val);
mul_tag[node] = (mul_tag[node] % p + p) % p;
add_tag[node] = (add_tag[node] * val);
add_tag[node] = (add_tag[node] % p + p) % p;
return;
}
push_down(node, l, r);
int mid = l + (r - l) / 2;
if (ql <= mid) update_mul(node * 2, l, mid, ql, qr, val);
if (qr > mid) update_mul(node * 2 + 1, mid + 1, r, ql, qr, val);
push_up(node);
}
long long query(int node, int l, int r, int pos) {
if (l == r) {
return tree[node];
}
push_down(node, l, r);
int mid = l + (r - l) / 2;
if (pos <= mid) {
return query(node * 2, l, mid, pos);
} else {
return query(node * 2 + 1, mid + 1, r, pos);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
cin >> n >> q;
for (int i = 1; i <= n; ++i) {
cin >> arr[i];
}
build(1, 1, n);
for (int i = 0; i < q; ++i) {
int type;
cin >> type;
if (type == 1) { // 加法
int l, r;
long long x;
cin >> l >> r >> x;
update_add(1, 1, n, l, r, x);
} else if (type == 2) { // 乘法
int l, r;
long long x;
cin >> l >> r >> x;
update_mul(1, 1, n, l, r, x);
} else {
int pos;
cin >> pos;
cout << query(1, 1, n, pos) << "\n";
}
}
return 0;
}
import java.util.Scanner;
public class Main {
static final int MAXN = 100005;
static long[] arr = new long[MAXN];
static long[] tree = new long[4 * MAXN];
static long[] mulTag = new long[4 * MAXN];
static long[] addTag = new long[4 * MAXN];
static int n, q;
static final long p = 998244353;
static void pushUp(int node) {
tree[node] = (tree[node * 2] + tree[node * 2 + 1]) % p;
}
static void pushDown(int node, int l, int r) {
if (mulTag[node] == 1 && addTag[node] == 0) return;
int mid = l + (r - l) / 2;
int lc = node * 2, rc = node * 2 + 1;
tree[lc] = (tree[lc] * mulTag[node] + addTag[node] * (mid - l + 1));
tree[lc] = (tree[lc] % p + p) % p;
mulTag[lc] = (mulTag[lc] * mulTag[node]) % p;
addTag[lc] = (addTag[lc] * mulTag[node] + addTag[node]) % p;
tree[rc] = (tree[rc] * mulTag[node] + addTag[node] * (r - mid));
tree[rc] = (tree[rc] % p + p) % p;
mulTag[rc] = (mulTag[rc] * mulTag[node]) % p;
addTag[rc] = (addTag[rc] * mulTag[node] + addTag[node]) % p;
mulTag[node] = 1;
addTag[node] = 0;
}
static void build(int node, int l, int r) {
mulTag[node] = 1;
addTag[node] = 0;
if (l == r) {
tree[node] = (arr[l] % p + p) % p;
return;
}
int mid = l + (r - l) / 2;
build(node * 2, l, mid);
build(node * 2 + 1, mid + 1, r);
pushUp(node);
}
static void updateAdd(int node, int l, int r, int ql, int qr, long val) {
if (ql <= l && r <= qr) {
tree[node] = (tree[node] + val * (r - l + 1));
tree[node] = (tree[node] % p + p) % p;
addTag[node] = (addTag[node] + val);
addTag[node] = (addTag[node] % p + p) % p;
return;
}
pushDown(node, l, r);
int mid = l + (r - l) / 2;
if (ql <= mid) updateAdd(node * 2, l, mid, ql, qr, val);
if (qr > mid) updateAdd(node * 2 + 1, mid + 1, r, ql, qr, val);
pushUp(node);
}
static void updateMul(int node, int l, int r, int ql, int qr, long val) {
if (ql <= l && r <= qr) {
tree[node] = (tree[node] * val);
tree[node] = (tree[node] % p + p) % p;
mulTag[node] = (mulTag[node] * val);
mulTag[node] = (mulTag[node] % p + p) % p;
addTag[node] = (addTag[node] * val);
addTag[node] = (addTag[node] % p + p) % p;
return;
}
pushDown(node, l, r);
int mid = l + (r - l) / 2;
if (ql <= mid) updateMul(node * 2, l, mid, ql, qr, val);
if (qr > mid) updateMul(node * 2 + 1, mid + 1, r, ql, qr, val);
pushUp(node);
}
static long query(int node, int l, int r, int pos) {
if (l == r) {
return tree[node];
}
pushDown(node, l, r);
int mid = l + (r - l) / 2;
if (pos <= mid) {
return query(node * 2, l, mid, pos);
} else {
return query(node * 2 + 1, mid + 1, r, pos);
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
q = sc.nextInt();
for (int i = 1; i <= n; i++) {
arr[i] = sc.nextLong();
}
build(1, 1, n);
for (int i = 0; i < q; i++) {
int type = sc.nextInt();
if (type == 1) { // 加法
int l = sc.nextInt();
int r = sc.nextInt();
long x = sc.nextLong();
updateAdd(1, 1, n, l, r, x);
} else if (type == 2) { // 乘法
int l = sc.nextInt();
int r = sc.nextInt();
long x = sc.nextLong();
updateMul(1, 1, n, l, r, x);
} else {
int pos = sc.nextInt();
System.out.println(query(1, 1, n, pos));
}
}
}
}
import sys
# 增加递归深度
sys.setrecursionlimit(100005)
n, q = 0, 0
p = 998244353
arr = []
tree = []
mul_tag = []
add_tag = []
def push_up(node):
tree[node] = (tree[node * 2] + tree[node * 2 + 1]) % p
def push_down(node, l, r):
if mul_tag[node] == 1 and add_tag[node] == 0:
return
mid = l + (r - l) // 2
lc, rc = node * 2, node * 2 + 1
tree[lc] = (tree[lc] * mul_tag[node] + add_tag[node] * (mid - l + 1)) % p
mul_tag[lc] = (mul_tag[lc] * mul_tag[node]) % p
add_tag[lc] = (add_tag[lc] * mul_tag[node] + add_tag[node]) % p
tree[rc] = (tree[rc] * mul_tag[node] + add_tag[node] * (r - mid)) % p
mul_tag[rc] = (mul_tag[rc] * mul_tag[node]) % p
add_tag[rc] = (add_tag[rc] * mul_tag[node] + add_tag[node]) % p
mul_tag[node] = 1
add_tag[node] = 0
def build(node, l, r):
mul_tag[node] = 1
add_tag[node] = 0
if l == r:
tree[node] = arr[l] % p
return
mid = l + (r - l) // 2
build(node * 2, l, mid)
build(node * 2 + 1, mid + 1, r)
push_up(node)
def update_add(node, l, r, ql, qr, val):
if ql <= l and r <= qr:
tree[node] = (tree[node] + val * (r - l + 1)) % p
add_tag[node] = (add_tag[node] + val) % p
return
push_down(node, l, r)
mid = l + (r - l) // 2
if ql <= mid:
update_add(node * 2, l, mid, ql, qr, val)
if qr > mid:
update_add(node * 2 + 1, mid + 1, r, ql, qr, val)
push_up(node)
def update_mul(node, l, r, ql, qr, val):
if ql <= l and r <= qr:
tree[node] = (tree[node] * val) % p
mul_tag[node] = (mul_tag[node] * val) % p
add_tag[node] = (add_tag[node] * val) % p
return
push_down(node, l, r)
mid = l + (r - l) // 2
if ql <= mid:
update_mul(node * 2, l, mid, ql, qr, val)
if qr > mid:
update_mul(node * 2 + 1, mid + 1, r, ql, qr, val)
push_up(node)
def query(node, l, r, pos):
if l == r:
return tree[node]
push_down(node, l, r)
mid = l + (r - l) // 2
if pos <= mid:
return query(node * 2, l, mid, pos)
else:
return query(node * 2 + 1, mid + 1, r, pos)
def solve():
global n, q, p, arr, tree, mul_tag, add_tag
line1 = list(map(int, sys.stdin.readline().split()))
n, q = line1[0], line1[1]
arr = [0] + list(map(int, sys.stdin.readline().split()))
tree = [0] * (4 * (n + 1))
mul_tag = [0] * (4 * (n + 1))
add_tag = [0] * (4 * (n + 1))
build(1, 1, n)
results = []
for _ in range(q):
line = list(map(int, sys.stdin.readline().split()))
op_type = line[0]
if op_type == 1: # 加法
l, r, x = line[1], line[2], line[3]
update_add(1, 1, n, l, r, x)
elif op_type == 2: # 乘法
l, r, x = line[1], line[2], line[3]
update_mul(1, 1, n, l, r, x)
else: # 查询
pos = line[1]
results.append(str(query(1, 1, n, pos)))
if results:
sys.stdout.write("\n".join(results) + "\n")
solve()
算法及复杂度
- 算法:线段树 (Segment Tree) + 双重懒惰标记 (加法与乘法)
- 时间复杂度:
。建树需要
,后续
次操作每次需要
。
- 空间复杂度:
,用于存储线段树的节点信息及懒惰标记。