题目链接

REAL776 移山

题目描述

给出 座山和它们的高度 。总共有 次移山操作。每次操作会指定一个连续的区间 和一个高度 ,将这个区间内所有山的高度都减去

问题是,在哪一次操作之后,首次出现至少一个山头的高度小于等于 0?题目保证答案一定存在。

解题思路

1. 核心观察:单调性

这个问题的核心性质是单调性。如果经过前 次操作后,某座山的高度已经小于等于 0,那么再进行第 次操作后,它的高度只会更低,必然也小于等于 0。

这种“一次满足,永久满足”的性质,是应用二分答案算法的经典标志。

2. 二分答案

我们可以对“操作的次数”进行二分查找。我们要在 的范围内,找到一个最小的次数 ,使得“进行前 次操作后,至少有一个山头的高度小于等于 0”这个条件首次成立。

为了实现二分查找,我们需要一个 check(k) 函数,它的功能是判断:进行前 次操作后,是否存在高度小于等于 0 的山头?

  • 如果 check(k) 返回 true,说明前 次操作已经足够(或超额)了,真正的答案可能就是 ,或者比 更小。因此,我们尝试在左半区间 中继续寻找,并记录下 这个潜在的答案。
  • 如果 check(k) 返回 false,说明前 次操作还不够,我们需要更多的操作。因此,答案一定在右半区间 中。

3. 高效实现 check(k):差分数组

check(k) 函数需要计算前 次操作对所有山的累积影响。如果朴素地模拟每一次区间减法,check(k) 的时间复杂度将是 ,这会导致整个算法超时。

为了高效地处理这 区间修改,我们可以使用差分数组

  • 构建差分: 我们创建一个差分数组 diff,长度为 。对于前 次操作中的每一次操作 ,我们执行 diff[l] += ddiff[r+1] -= d。这 次操作总共只需要 的时间。
  • 还原减去的总高度: 对差分数组 diff 求一次前缀和,就能得到一个新数组 total_reduction,其中 total_reduction[i] 表示第 座山在前 次操作中总共被减去的高度。这个过程需要 的时间。
  • 判断条件: 最后,我们遍历所有山,检查是否存在一个 使得 h[i] <= total_reduction[i]。如果存在,check(k) 返回 true,否则返回 false。这个过程需要 的时间。

综上,check(k) 函数的总时间复杂度被优化到了

4. 整体流程

  • 二分答案的范围是
  • 每次取中点 mid,调用 check(mid)
  • 根据 check(mid) 的结果,收缩二分查找的区间。
  • 最终找到的最小的满足条件的 k 即为答案。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <tuple>

using namespace std;

// 检查进行前 k 次操作后,是否有山的高度 <= 0
bool check(int k, int n, const vector<long long>& h, const vector<tuple<int, int, int>>& ops) {
    vector<long long> diff(n + 2, 0);
    for (int i = 0; i < k; ++i) {
        auto [l, r, d] = ops[i];
        diff[l] += d;
        diff[r + 1] -= d;
    }

    long long current_reduction = 0;
    for (int i = 1; i <= n; ++i) {
        current_reduction += diff[i];
        if (h[i] <= current_reduction) {
            return true;
        }
    }
    return false;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;

    vector<long long> h(n + 1);
    for (int i = 1; i <= n; ++i) {
        cin >> h[i];
    }

    vector<tuple<int, int, int>> ops(m);
    for (int i = 0; i < m; ++i) {
        cin >> get<0>(ops[i]) >> get<1>(ops[i]) >> get<2>(ops[i]);
    }

    int left = 1, right = m;
    int ans = m;

    while (left <= right) {
        int mid = left + (right - left) / 2;
        if (check(mid, n, h, ops)) {
            ans = mid;
            right = mid - 1;
        } else {
            left = mid + 1;
        }
    }

    cout << ans << endl;

    return 0;
}
import java.util.Scanner;

public class Main {
    static int n, m;
    static long[] h;
    static int[][] ops;

    // 检查进行前 k 次操作后,是否有山的高度 <= 0
    private static boolean check(int k) {
        long[] diff = new long[n + 2];
        for (int i = 0; i < k; i++) {
            int l = ops[i][0];
            int r = ops[i][1];
            int d = ops[i][2];
            diff[l] += d;
            diff[r + 1] -= d;
        }

        long currentReduction = 0;
        for (int i = 1; i <= n; i++) {
            currentReduction += diff[i];
            if (h[i] <= currentReduction) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        n = sc.nextInt();
        m = sc.nextInt();

        h = new long[n + 1];
        for (int i = 1; i <= n; i++) {
            h[i] = sc.nextLong();
        }

        ops = new int[m][3];
        for (int i = 0; i < m; i++) {
            ops[i][0] = sc.nextInt();
            ops[i][1] = sc.nextInt();
            ops[i][2] = sc.nextInt();
        }

        int left = 1, right = m;
        int ans = m;

        while (left <= right) {
            int mid = left + (right - left) / 2;
            if (check(mid)) {
                ans = mid;
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }

        System.out.println(ans);
    }
}
import sys

def check(k, n, h, ops):
    """
    检查进行前 k 次操作后,是否有山的高度 <= 0
    """
    diff = [0] * (n + 2)
    for i in range(k):
        l, r, d = ops[i]
        diff[l] += d
        diff[r + 1] -= d
    
    current_reduction = 0
    for i in range(1, n + 1):
        current_reduction += diff[i]
        if h[i] <= current_reduction:
            return True
    return False

def solve():
    try:
        n, m = map(int, sys.stdin.readline().split())
        h = [0] + list(map(int, sys.stdin.readline().split()))
        ops = []
        for _ in range(m):
            ops.append(list(map(int, sys.stdin.readline().split())))
    except (IOError, ValueError):
        return

    left, right = 1, m
    ans = m

    while left <= right:
        mid = (left + right) // 2
        if check(mid, n, h, ops):
            ans = mid
            right = mid - 1
        else:
            left = mid + 1
            
    print(ans)

solve()

算法及复杂度

  • 算法: 二分答案 + 差分数组

  • 时间复杂度: 二分查找需要进行 次。每次 check 函数内部,构建差分数组需要 (其中 是当前二分的中点),还原总减少量并检查需要 。因此,单次 check 的复杂度为 。因为 最大可以取到 ,所以总的时间复杂度为

  • 空间复杂度: 需要存储山的高度,大小为 ;存储所有操作,大小为 ;差分数组的大小为 。因此,总的空间复杂度为