题目链接

小红的纸牌游戏

题目描述

小红和小紫玩一个纸牌游戏。牌堆由 张牌组成,每张牌上是数字 '0' 或 '1'。两人轮流从牌堆中拿走一张牌,小红先手。当牌堆剩下 张牌时游戏结束。这 张牌按原先的相对顺序组成一个二进制数。

小红的目标是使这个二进制数尽可能大,而小紫的目标是使其尽可能小。假设双方都采取最优策略,请问最后剩下的二进制数是多少?

解题思路

这是一个典型的博弈问题。由于双方都采取最优策略,我们需要分析每个玩家的目标以及他们会如何行动来实现这个目标。

玩家目标与策略

  • 小红 (Maximizer): 为了让最终的二进制数尽可能大,她希望最终结果的高位(左边)尽可能为 '1'。因此,她的首要策略是保留 '1'移除 '0'
  • 小紫 (Minimizer): 为了让最终的二进制数尽可能小,她希望最终结果的高位(左边)尽可能为 '0'。因此,她的首要策略是保留 '0'移除 '1'

关键洞察

游戏总共需要移除 张牌。小红先手,所以她会拿 张牌,小紫会拿 张牌。

由于小红和小紫的首要目标是互斥的(移除 '0' vs 移除 '1'),游戏过程可以被简化。我们不需要进行回合制的模拟。我们可以直接计算出在整个游戏过程中,小红和小紫分别会移除多少个 '0' 和 '1'。

  1. 优先移除:

    • 在小红的 个回合中,只要场上还有 '0',她就会优先拿走 '0'。
    • 在小紫的 个回合中,只要场上还有 '1',她就会优先拿走 '1'。
  2. 强制移除:

    • 如果轮到小红拿牌,但场上已经没有 '0' 了,她将被迫拿走一张 '1'。
    • 如果轮到小紫拿牌,但场上已经没有 '1' 了,她将被迫拿走一张 '0'。

移除数量计算

设牌堆中初始有 count_0 个 '0' 和 count_1 个 '1'。

  • 小红优先移除的 '0' 的数量为 zeros_removed_by_hong = min(count_0, h)
  • 小紫优先移除的 '1' 的数量为 ones_removed_by_zi = min(count_1, z)
  • 小红被迫移除的 '1' 的数量为 ones_removed_by_hong_forced = max(0, h - count_0)
  • 小紫被迫移除的 '0' 的数量为 zeros_removed_by_zi_forced = max(0, z - count_1)

移除位置选择

现在我们知道了总共要移除多少个 '0' 和 '1',还需要确定移除哪些位置的牌。

  • 小红移除 '0' (优先): 为了让结果最大化,她需要移除那些可能成为高位的 '0'。因此,她会拿走最左边的 '0'。
  • 小紫移除 '1' (优先): 为了让结果最小化,她需要移除那些可能成为高位的 '1'。因此,她会拿走最左边的 '1'。
  • 小红移除 '1' (被迫): 她被迫移除一张她想保留的 '1'。为了使损失最小,她会移除对结果数值影响最小的 '1',即最右边的 '1'。
  • 小紫移除 '0' (被迫): 她被迫移除一张她想保留的 '0'。为了使损失最小,她会移除对结果数值影响最小的 '0',即最右边的 '0'。

最终算法

  1. 计算总共需要移除的牌数 rem = n - k,以及小红和小紫的拿牌次数 hz
  2. 遍历一遍字符串,统计所有 '0' 和 '1' 的位置,分别存入两个列表 zero_indicesone_indices
  3. 根据上面的公式,计算出四种移除情况各自对应的数量。
  4. 根据移除位置选择策略,确定所有要被移除的牌的原始索引,并将它们放入一个集合中以便快速查找。
    • zero_indices 的前 zeros_removed_by_hong 个。
    • one_indices 的前 ones_removed_by_zi 个。
    • zero_indices 的后 zeros_removed_by_zi_forced 个。
    • one_indices 的后 ones_removed_by_hong_forced 个。
  5. 再次遍历原始字符串(从索引 0 到 n-1),如果当前索引不在移除集合中,则将该字符追加到结果字符串中。
  6. 输出结果字符串。

该算法的时间和空间复杂度均为

代码

#include <iostream>
#include <string>
#include <vector>
#include <numeric>
#include <algorithm>
#include <cmath>
#include <set>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, k;
    cin >> n >> k;
    string s;
    cin >> s;

    int rem = n - k;
    if (rem == 0) {
        cout << s << endl;
        return 0;
    }

    int hong_turns = (rem + 1) / 2;
    int zi_turns = rem / 2;

    vector<int> zero_indices, one_indices;
    for (int i = 0; i < n; ++i) {
        if (s[i] == '0') {
            zero_indices.push_back(i);
        } else {
            one_indices.push_back(i);
        }
    }

    int count_0 = zero_indices.size();
    int count_1 = one_indices.size();

    int zeros_by_hong = min(count_0, hong_turns);
    int ones_by_zi = min(count_1, zi_turns);
    int ones_by_hong_forced = max(0, hong_turns - count_0);
    int zeros_by_zi_forced = max(0, zi_turns - count_1);

    set<int> removed_indices;

    for (int i = 0; i < zeros_by_hong; ++i) {
        removed_indices.insert(zero_indices[i]);
    }
    for (int i = 0; i < ones_by_zi; ++i) {
        removed_indices.insert(one_indices[i]);
    }
    for (int i = 0; i < ones_by_hong_forced; ++i) {
        removed_indices.insert(one_indices[count_1 - 1 - i]);
    }
    for (int i = 0; i < zeros_by_zi_forced; ++i) {
        removed_indices.insert(zero_indices[count_0 - 1 - i]);
    }

    string result = "";
    for (int i = 0; i < n; ++i) {
        if (removed_indices.find(i) == removed_indices.end()) {
            result += s[i];
        }
    }

    cout << result << endl;

    return 0;
}
import java.util.Scanner;
import java.util.List;
import java.util.ArrayList;
import java.util.Set;
import java.util.HashSet;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();
        String s = sc.next();

        int rem = n - k;
        if (rem == 0) {
            System.out.println(s);
            return;
        }

        int hongTurns = (rem + 1) / 2;
        int ziTurns = rem / 2;

        List<Integer> zeroIndices = new ArrayList<>();
        List<Integer> oneIndices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (s.charAt(i) == '0') {
                zeroIndices.add(i);
            } else {
                oneIndices.add(i);
            }
        }

        int count0 = zeroIndices.size();
        int count1 = oneIndices.size();

        int zerosByHong = Math.min(count0, hongTurns);
        int onesByZi = Math.min(count1, ziTurns);
        int onesByHongForced = Math.max(0, hongTurns - count0);
        int zerosByZiForced = Math.max(0, ziTurns - count1);

        Set<Integer> removedIndices = new HashSet<>();
        for (int i = 0; i < zerosByHong; i++) {
            removedIndices.add(zeroIndices.get(i));
        }
        for (int i = 0; i < onesByZi; i++) {
            removedIndices.add(oneIndices.get(i));
        }
        for (int i = 0; i < onesByHongForced; i++) {
            removedIndices.add(oneIndices.get(count1 - 1 - i));
        }
        for (int i = 0; i < zerosByZiForced; i++) {
            removedIndices.add(zeroIndices.get(count0 - 1 - i));
        }

        StringBuilder result = new StringBuilder();
        for (int i = 0; i < n; i++) {
            if (!removedIndices.contains(i)) {
                result.append(s.charAt(i));
            }
        }

        System.out.println(result.toString());
    }
}
import sys

def solve():
    n, k = map(int, sys.stdin.readline().split())
    s = sys.stdin.readline().strip()

    rem = n - k
    if rem == 0:
        print(s)
        return

    hong_turns = (rem + 1) // 2
    zi_turns = rem // 2

    zero_indices = [i for i, char in enumerate(s) if char == '0']
    one_indices = [i for i, char in enumerate(s) if char == '1']
    
    count_0 = len(zero_indices)
    count_1 = len(one_indices)

    zeros_by_hong = min(count_0, hong_turns)
    ones_by_zi = min(count_1, zi_turns)
    ones_by_hong_forced = max(0, hong_turns - count_0)
    zeros_by_zi_forced = max(0, zi_turns - count_1)
    
    removed_indices = set()
    
    removed_indices.update(zero_indices[:zeros_by_hong])
    removed_indices.update(one_indices[:ones_by_zi])
    removed_indices.update(one_indices[count_1 - ones_by_hong_forced:])
    removed_indices.update(zero_indices[count_0 - zeros_by_zi_forced:])
    
    result = []
    for i, char in enumerate(s):
        if i not in removed_indices:
            result.append(char)
            
    print("".join(result))

solve()

算法及复杂度

  • 算法:博弈、贪心
  • 时间复杂度:算法主要包括遍历字符串以收集索引、计算移除数量和最后构建结果字符串。所有这些步骤都是线性的。因此,总时间复杂度为
  • 空间复杂度:需要额外的空间来存储 '0' 和 '1' 的索引,以及被移除牌的索引集合。在最坏的情况下,这需要 的空间。