题目链接

【入门班】第k大与第m大

题目描述

给定一个长度为 的数组 。通过以下规则构造一个新数组 : 对于 中所有长度不小于 的连续子数组,找到该子数组的第 大元素,并将其加入 。 最终,需要输出数组 中第 大的元素。

解题思路

这是一个在隐式构造的巨大集合中查找第 大元素的问题。直接构造数组 的复杂度过高,无法接受。这类问题是二分答案的典型应用场景。

我们可以二分最终的答案,设为 。然后,我们需要一个 check(x) 函数来判定一个候选值 是否可行。对于“第 大”的查询,check(x) 函数需要回答:在数组 中,大于或等于 的元素是否至少有 个?

  • 如果 check(x) 为真,说明真实的第 大元素不小于 ,我们可以尝试更大的值,即 left = mid + 1
  • 如果 check(x) 为假,说明 太大了,真实的第 大元素比 小,即 right = mid - 1

check(x) 的核心是计算满足特定条件的子数组数量。一个子数组的元素会被加入 当且仅当它是其所在子数组(长度 )的第 大元素。我们需要计算,有多少个这样的元素值 。这等价于计算:有多少个长度不小于 的子数组,其第 大的元素值

这个条件可以进一步转化为一个更易于处理的形式:“一个子数组的第 大元素 等价于 “该子数组中至少有 个元素的值 ”。

现在问题变成了:给定一个值 ,如何快速计算数组 中有多少个长度不小于 的子数组,满足其内部 的元素个数不少于 个。 这可以用一个 的线性扫描方法解决:

  1. 我们从左到右遍历数组 的每一个位置 (作为子数组的右端点)。
  2. 对于每一个 ,我们需要计算有多少个合法的左端点 )能构成满足条件的子数组
  3. 一个子数组 满足条件,当且仅当它包含至少 的元素。
  4. 我们可以找到这样一个临界位置 :它是从右往左数,在 中第 的元素所在的位置。那么,任何以 为右端点,以 中任一下标为左端点的子数组,都必然满足条件。
  5. 因此,对于每个右端点 ,满足条件的子数组数量就是
  6. 为了在遍历 的过程中高效地找到这个 ,我们可以用一个队列来维护我们遇到的最近 的元素的下标。队列的头部就是我们需要的

通过这种方式,check(x) 可以在 时间内完成。整个算法的时间复杂度为 ,其中 是数组中元素的值域。

代码

#include <iostream>
#include <vector>
#include <queue>
#include <numeric>

using namespace std;

bool check(long long x, int n, int k, long long m, const vector<int>& a) {
    long long count = 0;
    queue<int> q;
    for (int i = 0; i < n; ++i) {
        if (a[i] >= x) {
            q.push(i);
        }
        if (q.size() > k) {
            q.pop();
        }
        if (q.size() == k) {
            count += (long long)(q.front() + 1);
        }
        if (count >= m) return true; // 提前退出优化
    }
    return count >= m;
}

void solve() {
    int n, k;
    long long m;
    cin >> n >> k >> m;
    vector<int> a(n);
    int max_val = 0;
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        if (a[i] > max_val) {
            max_val = a[i];
        }
    }

    long long left = 1, right = max_val, ans = 1;
    while (left <= right) {
        long long mid = left + (right - left) / 2;
        if (mid == 0) { // 避免 mid 为 0
            left = 1;
            continue;
        }
        if (check(mid, n, k, m, a)) {
            ans = mid;
            left = mid + 1;
        } else {
            right = mid - 1;
        }
    }
    cout << ans << endl;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int t;
    cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}

import java.util.Scanner;
import java.util.Queue;
import java.util.LinkedList;

public class Main {
    private static boolean check(long x, int n, int k, long m, int[] a) {
        long count = 0;
        Queue<Integer> q = new LinkedList<>();
        for (int i = 0; i < n; i++) {
            if (a[i] >= x) {
                q.add(i);
            }
            if (q.size() > k) {
                q.poll();
            }
            if (q.size() == k) {
                count += (long)(q.peek() + 1);
            }
            if (count >= m) return true;
        }
        return count >= m;
    }

    private static void solve(Scanner sc) {
        int n = sc.nextInt();
        int k = sc.nextInt();
        long m = sc.nextLong();
        int[] a = new int[n];
        int maxVal = 0;
        for (int i = 0; i < n; i++) {
            a[i] = sc.nextInt();
            if (a[i] > maxVal) {
                maxVal = a[i];
            }
        }

        long left = 1, right = maxVal, ans = 1;
        while (left <= right) {
            long mid = left + (right - left) / 2;
            if (mid == 0) {
                left = 1;
                continue;
            }
            if (check(mid, n, k, m, a)) {
                ans = mid;
                left = mid + 1;
            } else {
                right = mid - 1;
            }
        }
        System.out.println(ans);
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int t = sc.nextInt();
        while (t-- > 0) {
            solve(sc);
        }
    }
}

import sys
from collections import deque

def check(x, n, k, m, a):
    count = 0
    q = deque()
    for i in range(n):
        if a[i] >= x:
            q.append(i)
        if len(q) > k:
            q.popleft()
        if len(q) == k:
            count += q[0] + 1
        if count >= m:
            return True
    return count >= m

def solve():
    n, k, m = map(int, sys.stdin.readline().split())
    a = list(map(int, sys.stdin.readline().split()))

    left, right = 1, 0
    if a:
        right = max(a)
    
    ans = 1
    
    while left <= right:
        mid = (left + right) // 2
        if mid == 0:
            left = 1
            continue
        if check(mid, n, k, m, a):
            ans = mid
            left = mid + 1
        else:
            right = mid - 1
            
    print(ans)

def main():
    t_str = sys.stdin.readline()
    if not t_str:
        return
    t = int(t_str)
    for _ in range(t):
        solve()

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:二分答案 + 线性扫描(使用队列维护窗口)。
  • 时间复杂度,其中 是测试数据组数, 是数组大小, 是数组中元素的最大值。对于每组数据,二分查找需要 次,每次 check 函数的复杂度为
  • 空间复杂度,用于存储队列中的下标,在最坏情况下为