题目链接

山峰数组计数

题目描述

定义一个山峰数组为一个长度为 3 的数组 ,满足

给定一个长度为 的正整数数组 。你需要选择两个下标 ),并将数组 划分成三个非空的连续子数组,其元素和分别为:

若三元组 构成一个山峰数组,则称二元组 是一个可行的二元组。请计算共有多少个不同的可行二元组

输入:

  • 第一行输入一个整数 ()。
  • 第二行输入 个整数 ()。

输出:

  • 输出一个整数,表示可行二元组的数量。

解题思路

本题的核心是统计满足特定条件的分割点 的数量。注意到 的范围很大,一个 的暴力枚举解法会超时,我们需要一个更高效的算法。

我们可以使用前缀和来快速计算子数组的和,并通过二分查找来加速计数过程。

  1. 预处理前缀和: 首先,我们计算数组 的前缀和数组 psps[k] 存储 。为了方便计算,ps[0] 可以设为0。这样,三个部分的和可以表示为:

  2. 转换山峰条件: 山峰数组的条件 可以用前缀和表示:

  3. 迭代与二分查找: 我们可以固定第一个分割点 ,然后去寻找所有满足条件的第二个分割点

    • 外层循环遍历 ,范围从 (确保所有部分都非空)。
    • 对于每个固定的 ,我们需要找到有多少个 (其中 )同时满足上述两个条件。
    • 这两个条件可以合并为一个对 ps[j] 的要求。由于 long long 类型的除法可能丢失精度,我们比较 2 * ps[j]ps[n] + ps[i]ps[j] 必须大于 2 * ps[i],并且 ps[j] 必须大于 (ps[n] + ps[i]) / 2
    • 因为数组 的元素都是正整数,前缀和数组 ps 是严格单调递增的。我们可以对 的范围 进行二分查找,找到第一个满足两个条件的 ps[j]
    • threshold1 = 2 * ps[i]。我们在 ps 数组的 [i+1, n-1] 范围内找到第一个大于 threshold1 的位置 it1
    • threshold2_num = ps[n] + ps[i]。我们在此范围内找到第一个 ps[j] 使得 2 * ps[j] > threshold2_num 的位置 it2
    • 同时满足两个条件的第一个 的位置就是 max(it1, it2)。从这个位置开始到 的所有 都满足条件。
    • 将每个 对应的 的数量累加起来,即可得到最终答案。

该算法的总时间复杂度为 ,足以通过本题。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;
    vector<long long> ps(n + 1, 0);
    for (int i = 0; i < n; ++i) {
        int p;
        cin >> p;
        ps[i + 1] = ps[i] + p;
    }

    long long count = 0;
    for (int i = 1; i <= n - 2; ++i) {
        // Condition 1: ps[j] > 2 * ps[i]
        auto it1 = upper_bound(ps.begin() + i + 1, ps.begin() + n, 2 * ps[i]);
        
        // Condition 2: 2 * ps[j] > ps[n] + ps[i]
        long long threshold2_num = ps[n] + ps[i];
        auto it2 = lower_bound(ps.begin() + i + 1, ps.begin() + n, 0, 
            [&](long long val, long long dummy) {
                return 2 * val <= threshold2_num;
            });

        // Find the starting index for valid j's
        auto start_it = max(it1, it2);
        
        // Add the number of valid j's
        count += distance(start_it, ps.begin() + n);
    }

    cout << count << endl;

    return 0;
}
import java.util.Scanner;
import java.util.Arrays;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        long[] ps = new long[n + 1];
        for (int i = 0; i < n; i++) {
            ps[i + 1] = ps[i] + sc.nextInt();
        }

        long count = 0;
        for (int i = 1; i <= n - 2; i++) {
            // Binary search for condition 1: ps[j] > 2 * ps[i]
            int low1 = i + 1, high1 = n - 1, ans1 = n;
            long threshold1 = 2 * ps[i];
            while (low1 <= high1) {
                int mid = low1 + (high1 - low1) / 2;
                if (ps[mid] > threshold1) {
                    ans1 = mid;
                    high1 = mid - 1;
                } else {
                    low1 = mid + 1;
                }
            }

            // Binary search for condition 2: 2 * ps[j] > ps[n] + ps[i]
            int low2 = i + 1, high2 = n - 1, ans2 = n;
            long threshold2_num = ps[n] + ps[i];
            while (low2 <= high2) {
                int mid = low2 + (high2 - low2) / 2;
                if (2 * ps[mid] > threshold2_num) {
                    ans2 = mid;
                    high2 = mid - 1;
                } else {
                    low2 = mid + 1;
                }
            }
            
            int start_j = Math.max(ans1, ans2);
            if (start_j <= n - 1) {
                count += (n - 1) - start_j + 1;
            }
        }
        System.out.println(count);
    }
}
import bisect

def solve():
    n = int(input())
    p = list(map(int, input().split()))

    ps = [0] * (n + 1)
    for i in range(n):
        ps[i+1] = ps[i] + p[i]

    count = 0
    # Fix i from 1 to n-2 (0-based: 0 to n-3)
    for i in range(1, n - 1):
        # Condition 1: ps[j] > 2 * ps[i]
        threshold1 = 2 * ps[i]
        # Find first j in [i+1, n-1] s.t. ps[j] > threshold1
        # bisect_right finds insertion point for threshold1, which is the index of first element > threshold1
        idx1 = bisect.bisect_right(ps, threshold1, lo=i + 1, hi=n)
        
        # Condition 2: 2 * ps[j] > ps[n] + ps[i]
        threshold2_num = ps[n] + ps[i]
        # Find first j in [i+1, n-1] s.t. 2*ps[j] > threshold2_num
        # We need custom binary search for this
        low, high = i + 1, n
        idx2 = n
        while low < high:
            mid = (low + high) // 2
            if 2 * ps[mid] > threshold2_num:
                high = mid
            else:
                low = mid + 1
        idx2 = low

        start_j = max(idx1, idx2)
        
        if start_j < n:
            count += (n - 1) - start_j + 1

    print(count)

solve()

算法及复杂度

  • 算法:前缀和 + 二分查找
  • 时间复杂度:。预计算前缀和需要 。之后是一个循环,遍历 个可能的 值,在循环内部进行两次二分查找,每次耗时 。因此总复杂度为
  • 空间复杂度:,用于存储前缀和数组。