题目链接

最小的最小未出现自然数

题目描述

给定一个长度为 的整数序列 和一个整数 。对于所有满足 ,需要计算

最终,需要求出这些 值中的最小值。

【名词解释】 最小未出现的自然数(mex): 指的是不包含在集合 中的最小非负整数。

解题思路

本题要求计算所有长度为 的滑动窗口的 值,并找出这些 值中的最小值。

一个朴素的解法是遍历每一个滑动窗口,对窗口内的元素计算 值,然后取所有结果的最小值。对于每个窗口,计算 的时间复杂度约为 。总共有 个窗口,因此总时间复杂度为 ,对于给定的数据范围会超时。

我们可以转换问题的角度。设最终答案为 ,即所有窗口的 值的最小值为 。这等价于以下两个条件同时成立:

  1. 对于任意一个窗口 ,都有
  2. 至少存在一个窗口 ,满足

条件 1 “对于任意窗口 ” 等价于 “每一个长度为 的窗口都包含了集合 中的所有整数”。

我们可以发现,如果“每一个窗口都包含 ”这个性质成立,那么对于 ,“每一个窗口都包含 ”也必然成立。这个单调性启发我们使用二分答案

我们可以二分查找最终的答案 。对于一个候选答案 mid,我们需要一个 check(mid) 函数来判断是否所有长度为 的窗口都包含了

  • 如果 check(mid) 为真,说明所有窗口的 值都至少为 mid。这意味着最终答案可能就是 mid,或者可能更大。我们尝试搜索更大的答案,即 low = mid + 1
  • 如果 check(mid) 为假,说明至少存在一个窗口,它的 值小于 mid。因此,最终答案必然小于 mid,我们缩减搜索范围,即 high = mid - 1

check(mid) 的实现是关键。要判断所有窗口是否都包含 ,我们只需要对该范围内的每一个数 进行检查,确保它没有在任何一个长度为 的窗口中缺席。

一个数 在某个窗口中缺席,等价于该窗口完全位于 的两次连续出现之间(或者在第一次出现之前,或在最后一次出现之后)。

  • 我们可以预处理出每个数字 在数组 中出现的所有位置。
  • 对于每个数字 ,计算出其连续出现位置之间的最大间隔。如果这个最大间隔大于 ,则说明存在一个长度为 的窗口可以“塞”在这个间隔里,从而不包含数字
  • 因此,check(mid) 的逻辑就是:对于 ,计算它们各自的最大间隔,如果其中任何一个的最大间隔大于 ,则 check(mid) 返回假。否则返回真。

为了使 check(mid) 的效率达到 ,我们可以进行预处理:

  1. 计算出每个数字 的最大间隔 max_gap[j]
  2. 计算 max_gap 数组的前缀最大值 prefix_max_gap[k] = max(max_gap[0], ..., max_gap[k])
  3. 这样,check(mid) 就变成了判断 prefix_max_gap[mid-1] <= m 是否成立。

总时间复杂度为 ,即

代码

#include <iostream>
#include <vector>
#include <algorithm>

int main() {
    using namespace std;
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;
    vector<vector<int>> pos(n + 1);
    for (int i = 0; i < n; ++i) {
        int val;
        cin >> val;
        if (val <= n) {
            pos[val].push_back(i);
        }
    }

    vector<int> max_gaps(n + 2);
    for (int j = 0; j <= n; ++j) {
        int last_pos = -1;
        int max_dist = 0;
        for (int p : pos[j]) {
            max_dist = max(max_dist, p - last_pos);
            last_pos = p;
        }
        max_dist = max(max_dist, n - last_pos);
        max_gaps[j] = max_dist;
    }

    vector<int> prefix_max_gaps(n + 2, 0);
    prefix_max_gaps[0] = max_gaps[0];
    for (int i = 1; i <= n + 1; ++i) {
        prefix_max_gaps[i] = max(prefix_max_gaps[i - 1], max_gaps[i]);
    }

    int low = 1, high = n + 2, ans = 0;
    while (low < high) {
        int mid = low + (high - low) / 2;
        if (prefix_max_gaps[mid - 1] <= m) {
            ans = mid;
            low = mid + 1;
        } else {
            high = mid;
        }
    }

    cout << ans << endl;

    return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.List;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        
        List<Integer>[] pos = new ArrayList[n + 1];
        for (int i = 0; i <= n; i++) {
            pos[i] = new ArrayList<>();
        }

        for (int i = 0; i < n; i++) {
            int val = sc.nextInt();
            if (val <= n) {
                pos[val].add(i);
            }
        }

        int[] maxGaps = new int[n + 2];
        for (int j = 0; j <= n; j++) {
            int lastPos = -1;
            int maxDist = 0;
            for (int p : pos[j]) {
                maxDist = Math.max(maxDist, p - lastPos);
                lastPos = p;
            }
            maxDist = Math.max(maxDist, n - lastPos);
            maxGaps[j] = maxDist;
        }

        int[] prefixMaxGaps = new int[n + 2];
        prefixMaxGaps[0] = maxGaps[0];
        for (int i = 1; i <= n + 1; i++) {
            prefixMaxGaps[i] = Math.max(prefixMaxGaps[i - 1], maxGaps[i]);
        }

        int low = 1, high = n + 2, ans = 0;
        while (low < high) {
            int mid = low + (high - low) / 2;
            if (mid > 0 && prefixMaxGaps[mid - 1] <= m) {
                ans = mid;
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        
        System.out.println(ans);
    }
}
import sys

def main():
    n, m = map(int, sys.stdin.readline().split())
    a = list(map(int, sys.stdin.readline().split()))
    
    pos = [[] for _ in range(n + 1)]
    for i, val in enumerate(a):
        if val <= n:
            pos[val].append(i)
            
    max_gaps = [0] * (n + 2)
    for j in range(n + 1):
        last_pos = -1
        max_dist = 0
        for p in pos[j]:
            max_dist = max(max_dist, p - last_pos)
            last_pos = p
        max_dist = max(max_dist, n - last_pos)
        max_gaps[j] = max_dist
        
    prefix_max_gaps = [0] * (n + 2)
    prefix_max_gaps[0] = max_gaps[0]
    for i in range(1, n + 2):
        prefix_max_gaps[i] = max(prefix_max_gaps[i - 1], max_gaps[i])

    low, high = 1, n + 2
    ans = 0
    while low < high:
        mid = (low + high) // 2
        if mid > 0 and prefix_max_gaps[mid - 1] <= m:
            ans = mid
            low = mid + 1
        else:
            high = mid
            
    print(ans)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法: 二分答案、预处理
  • 时间复杂度: 。预处理每个数字的位置、计算最大间隔、计算前缀最大值都是 。二分答案需要 次,但每次检查是 。总复杂度由预处理决定。
  • 空间复杂度: ,用于存储每个数字出现的位置。