题目链接
题目描述
给定一个长度为 的正整数数组
。你需要选择两个下标
(
),并将
划分成三个非空连续子数组:
,
,
。
若这三个子数组的和满足 且
,则称二元组
可行。请计算共有多少个不同的可行二元组。
解题思路
这是一个计数问题。一个朴素的想法是遍历所有可能的 对,然后计算三个子数组的和并进行比较。这会导致
的时间复杂度,对于本题的数据范围来说太慢。
为了优化,我们可以使用前缀和来快速计算子数组的和。设 为数组
前
个元素的和(
)。那么:
“山峰”条件可以转化为关于前缀和的不等式:
现在,问题转化为寻找满足 且同时满足上述两个不等式的整数对
的数量。
的遍历仍然不可行。我们可以采用枚举
,快速计算
的策略。
当我们固定一个
(
的范围为
),我们需要找到有多少个
(
的范围为
)满足:
由于数组 的元素都是正整数,前缀和数组
是一个严格单调递增的序列。这意味着对于一个固定的
,我们可以在有序的
上使用二分查找来快速地找到满足
的
的数量。
最终算法:
- 计算数组
的前缀和数组
。
- 初始化总数
。
- 遍历
从
到
: a. 计算阈值
。这里使用整数运算来处理严格小于的不等式。 b. 在
上使用二分查找(
lower_bound
或upper_bound
),找到满足的
的个数。 c. 将这个数量累加到
中。
这个算法的总时间复杂度为 。
代码
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<long long> p(n + 1, 0);
for (int i = 1; i <= n; ++i) {
long long val;
cin >> val;
p[i] = p[i - 1] + val;
}
long long ans = 0;
long long total_sum = p[n];
for (int j = 2; j < n; ++j) {
// 条件1: 2 * p[i] < p[j] => p[i] <= (p[j] - 1) / 2
long long threshold1 = (p[j] - 1) / 2;
// 条件2: p[i] < 2 * p[j] - total_sum => p[i] <= 2 * p[j] - total_sum - 1
long long threshold2 = 2 * p[j] - total_sum - 1;
long long upper_bound_val = min(threshold1, threshold2);
// 在 p[1...j-1] 中找有多少个 p[i] <= upper_bound_val
// lower_bound 找到第一个 >= val 的位置。
// upper_bound 找到第一个 > val 的位置。
// distance(begin, upper_bound(begin, end, val)) gives count of elements <= val
auto it = upper_bound(p.begin() + 1, p.begin() + j, upper_bound_val);
long long count = distance(p.begin() + 1, it);
ans += count;
}
cout << ans << endl;
return 0;
}
import java.util.Scanner;
import java.util.Arrays;
public class Main {
// 在 arr 的 [left, right] 区间内找第一个大于 target 的元素的索引
private static int upperBound(long[] arr, int left, int right, long target) {
int count = right - left;
while (count > 0) {
int step = count / 2;
int mid = left + step;
if (arr[mid] <= target) {
left = mid + 1;
count -= (step + 1);
} else {
count = step;
}
}
return left;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
long[] p = new long[n + 1];
p[0] = 0;
for (int i = 1; i <= n; i++) {
p[i] = p[i - 1] + sc.nextInt();
}
long ans = 0;
long totalSum = p[n];
for (int j = 2; j < n; j++) {
long threshold1 = (p[j] - 1) / 2;
long threshold2 = 2 * p[j] - totalSum - 1;
long upperBoundVal = Math.min(threshold1, threshold2);
if (upperBoundVal < p[1]) {
continue;
}
int pos = upperBound(p, 1, j, upperBoundVal);
long count = pos - 1;
ans += count;
}
System.out.println(ans);
}
}
import bisect
def main():
n = int(input())
a = list(map(int, input().split()))
p = [0] * (n + 1)
for i in range(n):
p[i + 1] = p[i] + a[i]
ans = 0
total_sum = p[n]
for j in range(2, n):
# 条件1: 2 * p[i] < p[j] => p[i] <= (p[j] - 1) // 2
threshold1 = (p[j] - 1) // 2
# 条件2: p[i] < 2 * p[j] - total_sum => p[i] <= 2 * p[j] - total_sum - 1
threshold2 = 2 * p[j] - total_sum - 1
upper_bound_val = min(threshold1, threshold2)
# 在 p[1...j-1] 中找有多少个 p[i] <= upper_bound_val
# bisect_right 找到第一个 > val 的位置的索引
# 由于 p 是 1-indexed (概念上), p[1] 在 python list 中是 index 1
# 在 p[1...j] (exclusive) 中查找
count = bisect.bisect_right(p, upper_bound_val, lo=1, hi=j) - 1
if count > 0:
ans += count
print(ans)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:前缀和 + 二分查找。
- 时间复杂度:
。预处理前缀和需要
。主循环遍历
从
到
,共
次。每次循环内部进行一次二分查找,耗时
。
- 空间复杂度:
,用于存储前缀和数组。