题目链接
题目描述
给定一个长度为 的数组
。通过以下规则构造一个新数组
:
对于
中所有长度不小于
的连续子数组,找到该子数组的第
大元素,并将其加入
。
最终,需要输出数组
中第
大的元素。
解题思路
这是一个在隐式构造的巨大集合中查找第 大元素的问题。直接构造数组
的复杂度过高,无法接受。这类问题是二分答案的典型应用场景。
我们可以二分最终的答案,设为 。然后,我们需要一个
check(x)
函数来判定一个候选值 是否可行。对于“第
大”的查询,
check(x)
函数需要回答:在数组 中,大于或等于
的元素是否至少有
个?
- 如果
check(x)
为真,说明真实的第大元素不小于
,我们可以尝试更大的值,即
left = mid + 1
。 - 如果
check(x)
为假,说明太大了,真实的第
大元素比
小,即
right = mid - 1
。
check(x)
的核心是计算满足特定条件的子数组数量。一个子数组的元素会被加入 当且仅当它是其所在子数组(长度
)的第
大元素。我们需要计算,有多少个这样的元素值
。这等价于计算:有多少个长度不小于
的子数组,其第
大的元素值
?
这个条件可以进一步转化为一个更易于处理的形式:“一个子数组的第 大元素
” 等价于 “该子数组中至少有
个元素的值
”。
现在问题变成了:给定一个值 ,如何快速计算数组
中有多少个长度不小于
的子数组,满足其内部
的元素个数不少于
个。
这可以用一个
的线性扫描方法解决:
- 我们从左到右遍历数组
的每一个位置
(作为子数组的右端点)。
- 对于每一个
,我们需要计算有多少个合法的左端点
(
)能构成满足条件的子数组
。
- 一个子数组
满足条件,当且仅当它包含至少
个
的元素。
- 我们可以找到这样一个临界位置
:它是从右往左数,在
中第
个
的元素所在的位置。那么,任何以
为右端点,以
中任一下标为左端点的子数组,都必然满足条件。
- 因此,对于每个右端点
,满足条件的子数组数量就是
。
- 为了在遍历
的过程中高效地找到这个
,我们可以用一个队列来维护我们遇到的最近
个
的元素的下标。队列的头部就是我们需要的
。
通过这种方式,check(x)
可以在 时间内完成。整个算法的时间复杂度为
,其中
是数组中元素的值域。
代码
#include <iostream>
#include <vector>
#include <queue>
#include <numeric>
using namespace std;
bool check(long long x, int n, int k, long long m, const vector<int>& a) {
long long count = 0;
queue<int> q;
for (int i = 0; i < n; ++i) {
if (a[i] >= x) {
q.push(i);
}
if (q.size() > k) {
q.pop();
}
if (q.size() == k) {
count += (long long)(q.front() + 1);
}
if (count >= m) return true; // 提前退出优化
}
return count >= m;
}
void solve() {
int n, k;
long long m;
cin >> n >> k >> m;
vector<int> a(n);
int max_val = 0;
for (int i = 0; i < n; ++i) {
cin >> a[i];
if (a[i] > max_val) {
max_val = a[i];
}
}
long long left = 1, right = max_val, ans = 1;
while (left <= right) {
long long mid = left + (right - left) / 2;
if (mid == 0) { // 避免 mid 为 0
left = 1;
continue;
}
if (check(mid, n, k, m, a)) {
ans = mid;
left = mid + 1;
} else {
right = mid - 1;
}
}
cout << ans << endl;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int t;
cin >> t;
while (t--) {
solve();
}
return 0;
}
import java.util.Scanner;
import java.util.Queue;
import java.util.LinkedList;
public class Main {
private static boolean check(long x, int n, int k, long m, int[] a) {
long count = 0;
Queue<Integer> q = new LinkedList<>();
for (int i = 0; i < n; i++) {
if (a[i] >= x) {
q.add(i);
}
if (q.size() > k) {
q.poll();
}
if (q.size() == k) {
count += (long)(q.peek() + 1);
}
if (count >= m) return true;
}
return count >= m;
}
private static void solve(Scanner sc) {
int n = sc.nextInt();
int k = sc.nextInt();
long m = sc.nextLong();
int[] a = new int[n];
int maxVal = 0;
for (int i = 0; i < n; i++) {
a[i] = sc.nextInt();
if (a[i] > maxVal) {
maxVal = a[i];
}
}
long left = 1, right = maxVal, ans = 1;
while (left <= right) {
long mid = left + (right - left) / 2;
if (mid == 0) {
left = 1;
continue;
}
if (check(mid, n, k, m, a)) {
ans = mid;
left = mid + 1;
} else {
right = mid - 1;
}
}
System.out.println(ans);
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int t = sc.nextInt();
while (t-- > 0) {
solve(sc);
}
}
}
import sys
from collections import deque
def check(x, n, k, m, a):
count = 0
q = deque()
for i in range(n):
if a[i] >= x:
q.append(i)
if len(q) > k:
q.popleft()
if len(q) == k:
count += q[0] + 1
if count >= m:
return True
return count >= m
def solve():
n, k, m = map(int, sys.stdin.readline().split())
a = list(map(int, sys.stdin.readline().split()))
left, right = 1, 0
if a:
right = max(a)
ans = 1
while left <= right:
mid = (left + right) // 2
if mid == 0:
left = 1
continue
if check(mid, n, k, m, a):
ans = mid
left = mid + 1
else:
right = mid - 1
print(ans)
def main():
t_str = sys.stdin.readline()
if not t_str:
return
t = int(t_str)
for _ in range(t):
solve()
if __name__ == "__main__":
main()
算法及复杂度
- 算法:二分答案 + 线性扫描(使用队列维护窗口)。
- 时间复杂度:
,其中
是测试数据组数,
是数组大小,
是数组中元素的最大值。对于每组数据,二分查找需要
次,每次
check
函数的复杂度为。
- 空间复杂度:
,用于存储队列中的下标,在最坏情况下为
。