题目链接

REAL740 小红的相等数组

题目描述

小红希望你构造一个长度为 的数组,满足:

  1. 数组中的每个元素 满足
  2. 数组所有元素的异或和小于等于所有元素的与和。即:

小红想知道有多少种可能的方案数。答案对 取模。

思路分析

1. 问题转换:按位分析

问题的核心约束是 异或和 <= 与和。这种涉及位运算的大小比较问题,通常可以从高位到低位逐位确定,这正是动态规划(DP),特别是数位DP (Digit DP) 的典型应用场景。

我们从高到低(从第 位到第 位)来考虑数组中所有 个元素的每一位。

XOR_sum 的第 位为 AND_sum 的第 位为 XOR_sum <= AND_sum 这个条件成立,当且仅当:

  • XOR_sum == AND_sum,即对于所有位
  • 或者,存在一个最高位 ,使得 ,并且对于所有比 更高的位 ,都有 。一旦 在高位确定下来,低位无论如何取值,最终的大小关系都不会改变。

这个结构天然适合 DP。我们定义状态来记录从高位到当前位的比较结果。

2. DP 状态与转移

我们定义 DP 状态 dp[i][state] 表示考虑了从高到低 位(即第 位到第 位)之后,满足特定状态的方案数。

  • dp[i][0]:前 位满足 XOR_sum 的高位部分 等于 AND_sum 的高位部分的方案数。
  • dp[i][1]:前 位满足 XOR_sum 的高位部分 小于 AND_sum 的高位部分的方案数。

DP 过程:我们从 dp[0][0]=1, dp[0][1]=0 (处理0位,高位部分为空,视作相等) 开始,递推计算到 dp[k]

对于当前位 ,我们有多少种选择方案(即 的当前位如何取值)可以使状态从 equal 变为 equal,从 equal 变为 less,或者从 less 保持 less

设当前位有 个1。

  • 当且仅当 (所有元素的第 位都是1)。
  • 当且仅当 是奇数。

状态转移系数

  1. ways_equal:使 的方案数。
    • 是偶数且
    • 是奇数且 (这要求 本身是奇数)。
  2. ways_less:使 的方案数。
    • 是偶数且 (这要求 本身是偶数)。
  3. ways_any: 任意选择,总共有 种方案。

利用组合数学恒等式 ,我们可以快速计算出 ways_equalways_less

  • 为偶数
    • ways_equal = (选偶数个1, 但不选n个) =
    • ways_less = (选n个1, n是偶数) =
  • 为奇数
    • ways_equal = (选偶数个1) + (选n个1, n是奇数) =
    • ways_less = 0 (因为 是奇数,不可能满足 是偶数)。

DP 递推式 (从 i = 1k):

  • dp[i][0] = (dp[i-1][0] * ways_equal) % M
  • dp[i][1] = (dp[i-1][0] * ways_less + dp[i-1][1] * 2^n) % M

最终答案为 (dp[k][0] + dp[k][1]) % M

代码

#include <iostream>
#include <vector>

using namespace std;

long long power(long long base, long long exp) {
    long long res = 1;
    base %= 1000000007;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % 1000000007;
        base = (base * base) % 1000000007;
        exp /= 2;
    }
    return res;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, k;
    cin >> n >> k;

    long long M = 1000000007;

    long long ways_equal, ways_less;
    long long p2_n_minus_1 = power(2, n - 1);

    if (n % 2 == 0) {
        ways_equal = (p2_n_minus_1 - 1 + M) % M;
        ways_less = 1;
    } else {
        ways_equal = (p2_n_minus_1 + 1) % M;
        ways_less = 0;
    }

    long long p2_n = power(2, n);
    long long dp_equal = 1;
    long long dp_less = 0;

    for (int i = 0; i < k; ++i) {
        long long next_dp_equal = (dp_equal * ways_equal) % M;
        long long next_dp_less = (dp_equal * ways_less + dp_less * p2_n) % M;
        dp_equal = next_dp_equal;
        dp_less = next_dp_less;
    }

    cout << (dp_equal + dp_less) % M << endl;

    return 0;
}
import java.util.Scanner;

public class Main {
    static long M = 1_000_000_007;

    public static long power(long base, long exp) {
        long res = 1;
        base %= M;
        while (exp > 0) {
            if (exp % 2 == 1) {
                res = (res * base) % M;
            }
            base = (base * base) % M;
            exp /= 2;
        }
        return res;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();

        long waysEqual, waysLess;
        long p2nMinus1 = power(2, n - 1);

        if (n % 2 == 0) {
            waysEqual = (p2nMinus1 - 1 + M) % M;
            waysLess = 1;
        } else {
            waysEqual = (p2nMinus1 + 1) % M;
            waysLess = 0;
        }

        long p2n = power(2, n);
        long dpEqual = 1;
        long dpLess = 0;

        for (int i = 0; i < k; i++) {
            long nextDpEqual = (dpEqual * waysEqual) % M;
            long nextDpLess = (dpEqual * waysLess + dpLess * p2n) % M;
            dpEqual = nextDpEqual;
            dpLess = nextDpLess;
        }

        System.out.println((dpEqual + dpLess) % M);
    }
}
import sys

def power(base, exp, mod):
    res = 1
    base %= mod
    while exp > 0:
        if exp % 2 == 1:
            res = (res * base) % mod
        base = (base * base) % mod
        exp //= 2
    return res

def solve():
    n, k = map(int, sys.stdin.readline().split())
    M = 10**9 + 7

    p2_n_minus_1 = power(2, n - 1, M)

    if n % 2 == 0:
        ways_equal = (p2_n_minus_1 - 1 + M) % M
        ways_less = 1
    else:
        ways_equal = (p2_n_minus_1 + 1) % M
        ways_less = 0
    
    p2_n = power(2, n, M)
    
    dp_equal = 1
    dp_less = 0
    
    for _ in range(k):
        next_dp_equal = (dp_equal * ways_equal) % M
        next_dp_less = (dp_equal * ways_less + dp_less * p2_n) % M
        dp_equal = next_dp_equal
        dp_less = next_dp_less
        
    print((dp_equal + dp_less) % M)

solve()

算法及复杂度

  • 算法:动态规划(按位DP)
  • 时间复杂度。DP 循环进行 次,每次内部有常数次模乘。预计算转移系数时需要计算模幂,其复杂度为 。因此总复杂度为
  • 空间复杂度。DP 状态的转移只需要常数个变量。