题目链接
PEEK68 【模板】动态区间和Ⅱ ‖ 区间修改 + 区间查询
题目描述
给定一个长度为 的数组
。你需要支持两种操作:
- 区间修改:将区间
内的所有元素都加上一个值
。
- 区间和查询:查询区间
内所有元素的和。
解题思路
本题需要同时支持区间的动态修改和区间的查询,是数据结构领域的经典问题。朴素的数组操作(区间修改 ,区间查询
)在大量操作下会超时。
解决此问题的标准高效方法是使用 线段树 (Segment Tree),并结合 懒惰标记 (Lazy Propagation) 来处理区间修改。
-
线段树 线段树是一种二叉树,每个节点代表原数组的一个区间
。
- 根节点代表整个区间
。
- 对于任意节点
(其中
),它的左子节点代表区间
,右子节点代表区间
,其中
。
- 每个节点存储其所代表区间的元素之和
sum
。
- 根节点代表整个区间
-
懒惰标记 如果每次区间修改都更新到叶子节点,复杂度仍然是
。懒惰标记是优化区间修改的关键。
- 每个节点额外维护一个
lazy_tag
,表示“欠”其子孙节点的增量值。 - 当一个修改操作
update(L, R, x)
完全覆盖了某个节点p
所代表的区间[l, r]
时,我们不再继续向下递归。我们直接更新节点p
的sum
(sum += x * (r - l + 1)
),并将增量累加到
p
的lazy_tag
上。 - 当后续操作(查询或修改)需要访问
p
的子节点时,我们先执行一次下推 (push_down
) 操作:将p
的lazy_tag
值应用到其左右子节点的sum
和lazy_tag
上,然后清空p
的lazy_tag
。
- 每个节点额外维护一个
算法流程
build
():根据初始数组递归构建线段树。
update
():从根节点递归修改。对于部分重叠的区间,先
push_down
再递归子节点。query
():从根节点递归查询。对于部分重叠的区间,先
push_down
再递归子节点求和。
性能说明:由于本题数据量较大,为确保通过,Java 和 Python 的实现将采用更快的 I/O 方式。
代码
#include <iostream>
#include <vector>
using namespace std;
const int MAXN = 500005;
long long arr[MAXN];
long long tree[4 * MAXN];
long long lazy[4 * MAXN];
int n, m;
void push_up(int node) {
tree[node] = tree[node * 2] + tree[node * 2 + 1];
}
void push_down(int node, int l, int r) {
if (lazy[node] == 0) return;
int mid = l + (r - l) / 2;
// 更新左子节点的 sum 和 lazy_tag
tree[node * 2] += lazy[node] * (mid - l + 1);
lazy[node * 2] += lazy[node];
// 更新右子节点的 sum 和 lazy_tag
tree[node * 2 + 1] += lazy[node] * (r - (mid + 1) + 1);
lazy[node * 2 + 1] += lazy[node];
// 清空当前节点的 lazy_tag
lazy[node] = 0;
}
void build(int node, int l, int r) {
if (l == r) {
tree[node] = arr[l];
return;
}
int mid = l + (r - l) / 2;
build(node * 2, l, mid);
build(node * 2 + 1, mid + 1, r);
push_up(node);
}
void update(int node, int l, int r, int ql, int qr, long long val) {
if (ql <= l && r <= qr) {
tree[node] += val * (r - l + 1);
lazy[node] += val;
return;
}
push_down(node, l, r);
int mid = l + (r - l) / 2;
if (ql <= mid) {
update(node * 2, l, mid, ql, qr, val);
}
if (qr > mid) {
update(node * 2 + 1, mid + 1, r, ql, qr, val);
}
push_up(node);
}
long long query(int node, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) {
return tree[node];
}
push_down(node, l, r);
int mid = l + (r - l) / 2;
long long sum = 0;
if (ql <= mid) {
sum += query(node * 2, l, mid, ql, qr);
}
if (qr > mid) {
sum += query(node * 2 + 1, mid + 1, r, ql, qr);
}
return sum;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
cin >> arr[i];
}
build(1, 1, n);
for (int i = 0; i < m; ++i) {
int type;
cin >> type;
if (type == 1) {
int l, r;
long long x;
cin >> l >> r >> x;
update(1, 1, n, l, r, x);
} else {
int l, r;
cin >> l >> r;
cout << query(1, 1, n, l, r) << "\n";
}
}
return 0;
}
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.IOException;
import java.util.StringTokenizer;
public class Main {
static final int MAXN = 500005;
static long[] arr = new long[MAXN];
static long[] tree = new long[4 * MAXN];
static long[] lazy = new long[4 * MAXN];
static int n, m;
static void pushUp(int node) {
tree[node] = tree[node * 2] + tree[node * 2 + 1];
}
static void pushDown(int node, int l, int r) {
if (lazy[node] == 0) return;
int mid = l + (r - l) / 2;
tree[node * 2] += lazy[node] * (mid - l + 1);
lazy[node * 2] += lazy[node];
tree[node * 2 + 1] += lazy[node] * (r - mid);
lazy[node * 2 + 1] += lazy[node];
lazy[node] = 0;
}
static void build(int node, int l, int r) {
if (l == r) {
tree[node] = arr[l];
return;
}
int mid = l + (r - l) / 2;
build(node * 2, l, mid);
build(node * 2 + 1, mid + 1, r);
pushUp(node);
}
static void update(int node, int l, int r, int ql, int qr, long val) {
if (ql <= l && r <= qr) {
tree[node] += val * (r - l + 1);
lazy[node] += val;
return;
}
pushDown(node, l, r);
int mid = l + (r - l) / 2;
if (ql <= mid) {
update(node * 2, l, mid, ql, qr, val);
}
if (qr > mid) {
update(node * 2 + 1, mid + 1, r, ql, qr, val);
}
pushUp(node);
}
static long query(int node, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) {
return tree[node];
}
pushDown(node, l, r);
int mid = l + (r - l) / 2;
long sum = 0;
if (ql <= mid) {
sum += query(node * 2, l, mid, ql, qr);
}
if (qr > mid) {
sum += query(node * 2 + 1, mid + 1, r, ql, qr);
}
return sum;
}
public static void main(String[] args) throws IOException {
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
PrintWriter writer = new PrintWriter(System.out);
StringTokenizer st = new StringTokenizer(reader.readLine());
n = Integer.parseInt(st.nextToken());
m = Integer.parseInt(st.nextToken());
st = new StringTokenizer(reader.readLine());
for (int i = 1; i <= n; i++) {
arr[i] = Long.parseLong(st.nextToken());
}
build(1, 1, n);
for (int i = 0; i < m; i++) {
st = new StringTokenizer(reader.readLine());
int type = Integer.parseInt(st.nextToken());
if (type == 1) {
int l = Integer.parseInt(st.nextToken());
int r = Integer.parseInt(st.nextToken());
long x = Long.parseLong(st.nextToken());
update(1, 1, n, l, r, x);
} else {
int l = Integer.parseInt(st.nextToken());
int r = Integer.parseInt(st.nextToken());
writer.println(query(1, 1, n, l, r));
}
}
writer.flush();
}
}
import sys
# 为应对大数据和深递归,增加递归深度并使用快速I/O
sys.setrecursionlimit(100005)
input = sys.stdin.readline
n, m = 0, 0
arr = []
tree = []
lazy = []
def push_up(node):
tree[node] = tree[node * 2] + tree[node * 2 + 1]
def push_down(node, l, r):
if lazy[node] == 0:
return
mid = l + (r - l) // 2
tree[node * 2] += lazy[node] * (mid - l + 1)
lazy[node * 2] += lazy[node]
tree[node * 2 + 1] += lazy[node] * (r - mid)
lazy[node * 2 + 1] += lazy[node]
lazy[node] = 0
def build(node, l, r):
if l == r:
tree[node] = arr[l]
return
mid = l + (r - l) // 2
build(node * 2, l, mid)
build(node * 2 + 1, mid + 1, r)
push_up(node)
def update(node, l, r, ql, qr, val):
if ql <= l and r <= qr:
tree[node] += val * (r - l + 1)
lazy[node] += val
return
push_down(node, l, r)
mid = l + (r - l) // 2
if ql <= mid:
update(node * 2, l, mid, ql, qr, val)
if qr > mid:
update(node * 2 + 1, mid + 1, r, ql, qr, val)
push_up(node)
def query(node, l, r, ql, qr):
if ql <= l and r <= qr:
return tree[node]
push_down(node, l, r)
mid = l + (r - l) // 2
s = 0
if ql <= mid:
s += query(node * 2, l, mid, ql, qr)
if qr > mid:
s += query(node * 2 + 1, mid + 1, r, ql, qr)
return s
def solve():
global n, m, arr, tree, lazy
n, m = map(int, input().split())
arr = [0] + list(map(int, input().split()))
tree = [0] * (4 * (n + 1))
lazy = [0] * (4 * (n + 1))
build(1, 1, n)
results = []
for _ in range(m):
line = list(map(int, input().split()))
op_type = line[0]
if op_type == 1:
l, r, x = line[1], line[2], line[3]
update(1, 1, n, l, r, x)
else:
l, r = line[1], line[2]
results.append(str(query(1, 1, n, l, r)))
if results:
print("\n".join(results))
solve()
算法及复杂度
- 算法:线段树 (Segment Tree) + 懒惰标记 (Lazy Propagation)
- 时间复杂度:
。建树需要
,后续
次操作每次需要
。
- 空间复杂度:
,用于存储线段树。