题目链接

PEEK68 【模板】动态区间和Ⅱ ‖ 区间修改 + 区间查询

题目描述

给定一个长度为 的数组 。你需要支持两种操作:

  1. 区间修改:将区间 内的所有元素都加上一个值
  2. 区间和查询:查询区间 内所有元素的和。

解题思路

本题需要同时支持区间的动态修改和区间的查询,是数据结构领域的经典问题。朴素的数组操作(区间修改 ,区间查询 )在大量操作下会超时。

解决此问题的标准高效方法是使用 线段树 (Segment Tree),并结合 懒惰标记 (Lazy Propagation) 来处理区间修改。

  1. 线段树 线段树是一种二叉树,每个节点代表原数组的一个区间

    • 根节点代表整个区间
    • 对于任意节点 (其中 ),它的左子节点代表区间 ,右子节点代表区间 ,其中
    • 每个节点存储其所代表区间的元素之和 sum
  2. 懒惰标记 如果每次区间修改都更新到叶子节点,复杂度仍然是 。懒惰标记是优化区间修改的关键。

    • 每个节点额外维护一个 lazy_tag,表示“欠”其子孙节点的增量值。
    • 当一个修改操作 update(L, R, x) 完全覆盖了某个节点 p 所代表的区间 [l, r] 时,我们不再继续向下递归。我们直接更新节点 psumsum += x * (r - l + 1)),并将增量 累加到 plazy_tag 上。
    • 当后续操作(查询或修改)需要访问 p 的子节点时,我们先执行一次下推 (push_down) 操作:将 plazy_tag 值应用到其左右子节点的 sumlazy_tag 上,然后清空 plazy_tag

算法流程

  1. build ():根据初始数组递归构建线段树。
  2. update ():从根节点递归修改。对于部分重叠的区间,先 push_down 再递归子节点。
  3. query ():从根节点递归查询。对于部分重叠的区间,先 push_down 再递归子节点求和。

性能说明:由于本题数据量较大,为确保通过,Java 和 Python 的实现将采用更快的 I/O 方式。

代码

#include <iostream>
#include <vector>

using namespace std;

const int MAXN = 500005;
long long arr[MAXN];
long long tree[4 * MAXN];
long long lazy[4 * MAXN];
int n, m;

void push_up(int node) {
    tree[node] = tree[node * 2] + tree[node * 2 + 1];
}

void push_down(int node, int l, int r) {
    if (lazy[node] == 0) return;
    int mid = l + (r - l) / 2;
    // 更新左子节点的 sum 和 lazy_tag
    tree[node * 2] += lazy[node] * (mid - l + 1);
    lazy[node * 2] += lazy[node];
    // 更新右子节点的 sum 和 lazy_tag
    tree[node * 2 + 1] += lazy[node] * (r - (mid + 1) + 1);
    lazy[node * 2 + 1] += lazy[node];
    // 清空当前节点的 lazy_tag
    lazy[node] = 0;
}

void build(int node, int l, int r) {
    if (l == r) {
        tree[node] = arr[l];
        return;
    }
    int mid = l + (r - l) / 2;
    build(node * 2, l, mid);
    build(node * 2 + 1, mid + 1, r);
    push_up(node);
}

void update(int node, int l, int r, int ql, int qr, long long val) {
    if (ql <= l && r <= qr) {
        tree[node] += val * (r - l + 1);
        lazy[node] += val;
        return;
    }
    push_down(node, l, r);
    int mid = l + (r - l) / 2;
    if (ql <= mid) {
        update(node * 2, l, mid, ql, qr, val);
    }
    if (qr > mid) {
        update(node * 2 + 1, mid + 1, r, ql, qr, val);
    }
    push_up(node);
}

long long query(int node, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) {
        return tree[node];
    }
    push_down(node, l, r);
    int mid = l + (r - l) / 2;
    long long sum = 0;
    if (ql <= mid) {
        sum += query(node * 2, l, mid, ql, qr);
    }
    if (qr > mid) {
        sum += query(node * 2 + 1, mid + 1, r, ql, qr);
    }
    return sum;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> n >> m;
    for (int i = 1; i <= n; ++i) {
        cin >> arr[i];
    }
    build(1, 1, n);

    for (int i = 0; i < m; ++i) {
        int type;
        cin >> type;
        if (type == 1) {
            int l, r;
            long long x;
            cin >> l >> r >> x;
            update(1, 1, n, l, r, x);
        } else {
            int l, r;
            cin >> l >> r;
            cout << query(1, 1, n, l, r) << "\n";
        }
    }

    return 0;
}
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.IOException;
import java.util.StringTokenizer;

public class Main {
    static final int MAXN = 500005;
    static long[] arr = new long[MAXN];
    static long[] tree = new long[4 * MAXN];
    static long[] lazy = new long[4 * MAXN];
    static int n, m;

    static void pushUp(int node) {
        tree[node] = tree[node * 2] + tree[node * 2 + 1];
    }

    static void pushDown(int node, int l, int r) {
        if (lazy[node] == 0) return;
        int mid = l + (r - l) / 2;
        tree[node * 2] += lazy[node] * (mid - l + 1);
        lazy[node * 2] += lazy[node];
        tree[node * 2 + 1] += lazy[node] * (r - mid);
        lazy[node * 2 + 1] += lazy[node];
        lazy[node] = 0;
    }

    static void build(int node, int l, int r) {
        if (l == r) {
            tree[node] = arr[l];
            return;
        }
        int mid = l + (r - l) / 2;
        build(node * 2, l, mid);
        build(node * 2 + 1, mid + 1, r);
        pushUp(node);
    }

    static void update(int node, int l, int r, int ql, int qr, long val) {
        if (ql <= l && r <= qr) {
            tree[node] += val * (r - l + 1);
            lazy[node] += val;
            return;
        }
        pushDown(node, l, r);
        int mid = l + (r - l) / 2;
        if (ql <= mid) {
            update(node * 2, l, mid, ql, qr, val);
        }
        if (qr > mid) {
            update(node * 2 + 1, mid + 1, r, ql, qr, val);
        }
        pushUp(node);
    }

    static long query(int node, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) {
            return tree[node];
        }
        pushDown(node, l, r);
        int mid = l + (r - l) / 2;
        long sum = 0;
        if (ql <= mid) {
            sum += query(node * 2, l, mid, ql, qr);
        }
        if (qr > mid) {
            sum += query(node * 2 + 1, mid + 1, r, ql, qr);
        }
        return sum;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter writer = new PrintWriter(System.out);
        StringTokenizer st = new StringTokenizer(reader.readLine());
        n = Integer.parseInt(st.nextToken());
        m = Integer.parseInt(st.nextToken());

        st = new StringTokenizer(reader.readLine());
        for (int i = 1; i <= n; i++) {
            arr[i] = Long.parseLong(st.nextToken());
        }
        build(1, 1, n);

        for (int i = 0; i < m; i++) {
            st = new StringTokenizer(reader.readLine());
            int type = Integer.parseInt(st.nextToken());
            if (type == 1) {
                int l = Integer.parseInt(st.nextToken());
                int r = Integer.parseInt(st.nextToken());
                long x = Long.parseLong(st.nextToken());
                update(1, 1, n, l, r, x);
            } else {
                int l = Integer.parseInt(st.nextToken());
                int r = Integer.parseInt(st.nextToken());
                writer.println(query(1, 1, n, l, r));
            }
        }
        writer.flush();
    }
}
import sys

# 为应对大数据和深递归,增加递归深度并使用快速I/O
sys.setrecursionlimit(100005)
input = sys.stdin.readline

n, m = 0, 0
arr = []
tree = []
lazy = []

def push_up(node):
    tree[node] = tree[node * 2] + tree[node * 2 + 1]

def push_down(node, l, r):
    if lazy[node] == 0:
        return
    mid = l + (r - l) // 2
    tree[node * 2] += lazy[node] * (mid - l + 1)
    lazy[node * 2] += lazy[node]
    tree[node * 2 + 1] += lazy[node] * (r - mid)
    lazy[node * 2 + 1] += lazy[node]
    lazy[node] = 0

def build(node, l, r):
    if l == r:
        tree[node] = arr[l]
        return
    mid = l + (r - l) // 2
    build(node * 2, l, mid)
    build(node * 2 + 1, mid + 1, r)
    push_up(node)

def update(node, l, r, ql, qr, val):
    if ql <= l and r <= qr:
        tree[node] += val * (r - l + 1)
        lazy[node] += val
        return
    push_down(node, l, r)
    mid = l + (r - l) // 2
    if ql <= mid:
        update(node * 2, l, mid, ql, qr, val)
    if qr > mid:
        update(node * 2 + 1, mid + 1, r, ql, qr, val)
    push_up(node)

def query(node, l, r, ql, qr):
    if ql <= l and r <= qr:
        return tree[node]
    push_down(node, l, r)
    mid = l + (r - l) // 2
    s = 0
    if ql <= mid:
        s += query(node * 2, l, mid, ql, qr)
    if qr > mid:
        s += query(node * 2 + 1, mid + 1, r, ql, qr)
    return s

def solve():
    global n, m, arr, tree, lazy
    n, m = map(int, input().split())
    arr = [0] + list(map(int, input().split()))
    tree = [0] * (4 * (n + 1))
    lazy = [0] * (4 * (n + 1))
    
    build(1, 1, n)

    results = []
    for _ in range(m):
        line = list(map(int, input().split()))
        op_type = line[0]
        if op_type == 1:
            l, r, x = line[1], line[2], line[3]
            update(1, 1, n, l, r, x)
        else:
            l, r = line[1], line[2]
            results.append(str(query(1, 1, n, l, r)))
    
    if results:
        print("\n".join(results))

solve()

算法及复杂度

  • 算法:线段树 (Segment Tree) + 懒惰标记 (Lazy Propagation)
  • 时间复杂度。建树需要 ,后续 次操作每次需要
  • 空间复杂度,用于存储线段树。