mex

思路

这道题乍一看像是模拟题——每轮算个 MEX,然后所有元素减掉它,直到全部相等。但直接模拟肯定超时,因为操作轮数可以达到 级别。关键在于找到规律,跳过大量重复操作。

什么时候无解?

如果数组里没有 0,那 MEX = 0,减 0 等于啥都没变。除非数组本身已经全部相等(直接输出 0),否则永远不会改变,输出 -1。

目标状态一定是全 0

一旦有元素变成 0,它就永远是 0(因为 )。所以如果最终要全部相等,只能全部等于 0。

核心观察:用原始数组的去重排序来推导 MEX

假设我们把原始数组排序去重得到 (其中 )。经过若干轮操作后,累计减去的总量为 ,那么此时数组中出现的不同值就是:

$$

也就是说,我们不需要真的去修改数组,只需要追踪 就能知道当前的 MEX。

快速计算 MEX

当前有效值集合里,0 一定存在。MEX 就是从 1 开始第一个缺失的正整数。翻译成原始数组的语言:从 开始,看 里有多少个连续的整数。如果 里包含 ,但不包含 ,那 MEX =

两种情况交替出现

  1. 遇到间隙(gap) 中下一个大于 的值 ,说明 不在 中,MEX = 1。连续 轮都是 MEX = 1,可以一步跳过。
  1. 遇到连续段(run) 开始有一段连续整数,长度为 。此时 MEX = ,一轮操作搞定,指针跳过这 个元素。

这样指针 上只会单调前进,总复杂度 (排序)+ (扫描),非常高效。

代码

#include <bits/stdc++.h>
using namespace std;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;
    vector<long long> a(n);
    for(int i = 0; i < n; i++) cin >> a[i];

    // 已经全部相等,直接输出 0
    bool allEqual = true;
    for(int i = 1; i < n; i++){
        if(a[i] != a[0]){ allEqual = false; break; }
    }
    if(allEqual){
        cout << 0 << endl;
        return 0;
    }

    sort(a.begin(), a.end());

    // 排序去重
    vector<long long> s;
    s.push_back(a[0]);
    for(int i = 1; i < n; i++){
        if(a[i] != a[i-1]) s.push_back(a[i]);
    }

    // 没有 0,MEX = 0,永远不变
    if(s[0] != 0){
        cout << -1 << endl;
        return 0;
    }

    long long maxVal = s.back();
    long long cumsum = 0;
    long long ops = 0;
    int j = 1; // s[0] = 0,从 s[1] 开始看

    while(cumsum < maxVal){
        // 找到第一个 > cumsum 的值
        while(j < (int)s.size() && s[j] <= cumsum) j++;
        if(j >= (int)s.size()) break;

        // 间隙部分:连续 MEX=1 的轮数
        long long gap = s[j] - cumsum - 1;
        ops += gap;
        cumsum += gap;

        // 连续段:找从 s[j] 开始的连续整数长度 L
        int L = 1;
        while(j + L < (int)s.size() && s[j + L] == s[j] + L) L++;

        // 这一轮 MEX = L+1
        ops += 1;
        cumsum += (long long)(L + 1);
        j += L;
    }

    cout << ops << endl;
    return 0;
}
import java.util.*;
import java.io.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine().trim());
        StringTokenizer st = new StringTokenizer(br.readLine().trim());
        long[] a = new long[n];
        for (int i = 0; i < n; i++) a[i] = Long.parseLong(st.nextToken());

        boolean allEqual = true;
        for (int i = 1; i < n; i++) {
            if (a[i] != a[0]) { allEqual = false; break; }
        }
        if (allEqual) {
            System.out.println(0);
            return;
        }

        Arrays.sort(a);

        ArrayList<Long> sList = new ArrayList<>();
        sList.add(a[0]);
        for (int i = 1; i < n; i++) {
            if (a[i] != a[i - 1]) sList.add(a[i]);
        }
        long[] s = new long[sList.size()];
        for (int i = 0; i < s.length; i++) s[i] = sList.get(i);

        if (s[0] != 0) {
            System.out.println(-1);
            return;
        }

        long maxVal = s[s.length - 1];
        long cumsum = 0;
        long ops = 0;
        int j = 1;

        while (cumsum < maxVal) {
            while (j < s.length && s[j] <= cumsum) j++;
            if (j >= s.length) break;

            long gap = s[j] - cumsum - 1;
            ops += gap;
            cumsum += gap;

            int L = 1;
            while (j + L < s.length && s[j + L] == s[j] + L) L++;

            ops += 1;
            cumsum += (long)(L + 1);
            j += L;
        }

        System.out.println(ops);
    }
}
import sys
input = sys.stdin.readline

def solve():
    n = int(input())
    a = list(map(int, input().split()))

    if len(set(a)) == 1:
        print(0)
        return

    a.sort()
    s = [a[0]]
    for i in range(1, n):
        if a[i] != a[i-1]:
            s.append(a[i])

    if s[0] != 0:
        print(-1)
        return

    max_val = s[-1]
    cumsum = 0
    ops = 0
    j = 1

    while cumsum < max_val:
        while j < len(s) and s[j] <= cumsum:
            j += 1
        if j >= len(s):
            break

        gap = s[j] - cumsum - 1
        ops += gap
        cumsum += gap

        L = 1
        while j + L < len(s) and s[j + L] == s[j] + L:
            L += 1

        ops += 1
        cumsum += L + 1
        j += L

    print(ops)

solve()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', line => lines.push(line.trim()));
rl.on('close', () => {
    const n = parseInt(lines[0]);
    const a = lines[1].split(' ').map(BigInt);

    let allEqual = true;
    for (let i = 1; i < n; i++) {
        if (a[i] !== a[0]) { allEqual = false; break; }
    }
    if (allEqual) { console.log('0'); return; }

    a.sort((x, y) => (x < y ? -1 : x > y ? 1 : 0));

    const s = [a[0]];
    for (let i = 1; i < n; i++) {
        if (a[i] !== a[i-1]) s.push(a[i]);
    }

    if (s[0] !== 0n) { console.log('-1'); return; }

    const maxVal = s[s.length - 1];
    let cumsum = 0n;
    let ops = 0n;
    let j = 1;

    while (cumsum < maxVal) {
        while (j < s.length && s[j] <= cumsum) j++;
        if (j >= s.length) break;

        const gap = s[j] - cumsum - 1n;
        ops += gap;
        cumsum += gap;

        let L = 1;
        while (j + L < s.length && s[j + L] === s[j] + BigInt(L)) L++;

        ops += 1n;
        cumsum += BigInt(L + 1);
        j += L;
    }

    console.log(ops.toString());
});

复杂度分析

  • 时间复杂度,排序是瓶颈,后续扫描只需
  • 空间复杂度,存储去重后的数组。