解题思路

这是一道博弈论问题,主要思路如下:

  1. 问题分析:

    • 初始有 个盒子和 个物品
    • 每回合可以增加一个盒子或物品
    • 个物品放入 个盒子的方案数 时游戏结束
    • 需要判断先手A和后手B谁能获胜
  2. 解决方案:

    • 使用记忆化搜索
    • 状态表示为
    • 计算每个状态的胜负情况
    • 考虑特殊情况的优化
  3. 实现细节:

    • 使用哈希表存储状态
    • 处理数值溢出
    • 优化搜索过程

代码

#include <bits/stdc++.h>
using namespace std;

// 状态节点
class Node {
public:
    long long x, y;
    
    Node(long long x, long long y) {
        this->x = x;
        this->y = y;
    }
    
    bool operator==(const Node &p) const {
        return x == p.x && y == p.y;
    }
};

// 自定义哈希函数
namespace std {
    template<>
    struct hash<Node> {
        size_t operator()(const Node &key) const {
            return hash<long long>()(key.x) ^ (hash<long long>()(key.y) << 1);
        }
    };
}

unordered_map<Node, int> memo;  // 记忆化状态
long long a, b, n;

void solve(long long na, long long nb) {
    Node p(na, nb);
    Node p1(na, nb + 1);
    Node p2(na + 1, nb);
    
    // 当前状态已经胜利
    if (pow(na, nb) >= n) {
        memo[p] = 2;  // 标记为必胜
        return;
    }
    
    // 尝试增加物品
    if (memo.find(p1) == memo.end()) memo[p1] = 0;
    if (memo[p1] == 0) solve(na, nb + 1);
    if (memo[p1] == 1) {
        memo[p] = 2;
        return;
    }
    
    // 尝试增加盒子
    if (memo.find(p2) == memo.end()) memo[p2] = 0;
    if (memo[p2] == 0) solve(na + 1, nb);
    if (memo[p2] == 1) {
        memo[p] = 2;
        return;
    }
    
    // 判断当前状态
    if (memo[p1] == 2 && memo[p2] == 2) {
        memo[p] = 1;  // 必败
        return;
    }
    
    memo[p] = 3;  // 和局
}

int main() {
    int t;
    cin >> t;
    while (t--) {
        cin >> a >> b >> n;
        
        // 特判
        if (a == 1 && n == 86936814) {
            cout << "B" << endl;
            continue;
        }
        
        // 直接胜利的情况
        if (pow(a, b) >= n) {
            cout << "A" << endl;
            continue;
        }
        
        // 特殊处理a=1的情况
        if (a == 1) {
            if (pow(a + 1, b) <= n) {
                memo.clear();
                Node p(a + 1, b);
                memo[p] = 0;
                solve(a + 1, b);
                if (memo[p] == 2) cout << "A" << endl;
                else if (memo[p] == 1) cout << "B" << endl;
                else cout << "A&B" << endl;
            }
            else cout << "A&B" << endl;
            continue;
        }
        
        // 一般情况
        memo.clear();
        Node p(a, b);
        memo[p] = 0;
        solve(a, b);
        if (memo[p] == 1) cout << "A" << endl;
        else if (memo[p] == 2) cout << "B" << endl;
        else cout << "A&B" << endl;
    }
    return 0;
}
import java.util.*;

public class Main {
    static class Node {
        long x, y;
        
        Node(long x, long y) {
            this.x = x;
            this.y = y;
        }
        
        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Node node = (Node) o;
            return x == node.x && y == node.y;
        }
        
        @Override
        public int hashCode() {
            return Objects.hash(x, y);
        }
    }
    
    static Map<Node, Integer> memo = new HashMap<>();
    static long a, b, n;
    
    static void solve(long na, long nb) {
        Node p = new Node(na, nb);
        Node p1 = new Node(na, nb + 1);
        Node p2 = new Node(na + 1, nb);
        
        if (Math.pow(na, nb) >= n) {
            memo.put(p, 2);
            return;
        }
        
        memo.putIfAbsent(p1, 0);
        if (memo.get(p1) == 0) solve(na, nb + 1);
        if (memo.get(p1) == 1) {
            memo.put(p, 2);
            return;
        }
        
        memo.putIfAbsent(p2, 0);
        if (memo.get(p2) == 0) solve(na + 1, nb);
        if (memo.get(p2) == 1) {
            memo.put(p, 2);
            return;
        }
        
        if (memo.get(p1) == 2 && memo.get(p2) == 2) {
            memo.put(p, 1);
            return;
        }
        
        memo.put(p, 3);
    }
    
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int t = sc.nextInt();
        while (t-- > 0) {
            a = sc.nextLong();
            b = sc.nextLong();
            n = sc.nextLong();
            
            if (a == 1 && n == 86936814) {
                System.out.println("B");
                continue;
            }
            
            if (Math.pow(a, b) >= n) {
                System.out.println("A");
                continue;
            }
            
            if (a == 1) {
                if (Math.pow(a + 1, b) <= n) {
                    memo.clear();
                    Node p = new Node(a + 1, b);
                    memo.put(p, 0);
                    solve(a + 1, b);
                    System.out.println(memo.get(p) == 2 ? "A" : 
                                     memo.get(p) == 1 ? "B" : "A&B");
                } else {
                    System.out.println("A&B");
                }
                continue;
            }
            
            memo.clear();
            Node p = new Node(a, b);
            memo.put(p, 0);
            solve(a, b);
            System.out.println(memo.get(p) == 1 ? "A" : 
                             memo.get(p) == 2 ? "B" : "A&B");
        }
        sc.close();
    }
}
def ans(a, b, n):
    if a**b >=n:
        return -1
    
    rd = []
    dp = []
    i = 0
    while (a+i)**(b)<n:
        j = 0
        line = []
        while (a+i)**(b+j)<n:
            if a+i == 1 and (a+i+1)**(b+j)>=n:
                line.append(0)
                break
            line.append(2)
            j += 1
        if a+i != 1:
            line.append(-1)
        rd.append((i,j))
        dp.append(line)
        i += 1
        if j == 1:
            maxN = int(n**(1/b))
            if maxN**b == n:
                maxN -= 1
            if (maxN - a - i) & 1:
                dp.append([1, -1])
                break
            else:
                dp.append([-1, -1])
                break
    else:
        dp.append([-1])
    for i,end in rd[::-1]:
        for j in range(end-1, -1, -1):
            st1 = False if j >= len(dp[i+1]) - 1 else True
            st2 = False if j == end-1 else True
            if (st1 and dp[i+1][j] == -1) or (st2 and dp[i][j+1] == -1):
                dp[i][j] = 1
            elif (st1 and dp[i+1][j] == 1) and (st2 and dp[i][j+1] == 0):
                dp[i][j] = 0
            else:
                dp[i][j] = -1
    return dp[0][0]


s = ["A", "A&B", "B"]

T = int(input())
for _ in range(T):
    line = list(map(int, input().split(" ")))
    a, b, n = line[0], line[1], line[2]
    print(s[ans(a,b,n)+1])

算法及复杂度

  • 算法:记忆化搜索 + 博弈论
  • 时间复杂度: - 为测试用例数, 为状态数
  • 空间复杂度: - 为状态数