小红的相等数组

[题目链接](https://www.nowcoder.com/practice/d45de1dae26a49bca9d8832b9551737f)

思路

构造长度为 的数组,每个元素在 ,要求所有元素的异或和 不超过与和 。统计方案数,对 取模。

逐位分析

异或和与按位与都是逐位运算。对于第 位,设 个元素中该位为 1 的个数,则:

  • AND 第 当且仅当
  • XOR 第 当且仅当 为奇数

要满足 ,从最高位到最低位扫描,第一个 不同的位必须是 (即 在该位赢了)。

每一位有三种状态:

状态 条件 含义
相等 (E) 为奇数,或 为偶数
A 赢 (A) 为偶数
X 赢 (X) 为奇数

每一位的 取值对应 种数组分配方式。定义:

$$

利用 ,可得:

为奇数时

由于没有位能让 赢,唯一满足条件的情况是所有位都相等(),答案为

为偶数时

计算答案

个位各自独立选择状态。合法方案要求从高位到低位,第一个非 E 位必须是 A。枚举第一个非 E 位的位置 ),前 位全为 E,第 位为 A,后面 位任意(共 种选择):

$$

其中 对应全部位都相等的情况。由于 ,利用等比数列求和:

$$

分母 用费马小定理求逆元即可。

复杂度分析

  • 时间复杂度:,仅需若干次快速幂。
  • 空间复杂度:

代码

#include <iostream>
using namespace std;

const long long MOD = 1e9 + 7;

long long power(long long base, long long exp, long long mod) {
    long long result = 1;
    base %= mod;
    while (exp > 0) {
        if (exp & 1) result = result * base % mod;
        base = base * base % mod;
        exp >>= 1;
    }
    return result;
}

int main() {
    long long n, k;
    cin >> n >> k;

    long long T = power(2, n, MOD);

    if (n % 2 == 1) {
        long long half = power(2, n - 1, MOD);
        long long fE = (1 + half) % MOD;
        cout << power(fE, k, MOD) << endl;
    } else {
        long long half = power(2, n - 1, MOD);
        long long fE = (half - 1 + MOD) % MOD;
        long long Tk = power(T, k, MOD);
        long long fEk = power(fE, k, MOD);
        long long denom = (half + 1) % MOD;
        long long num = (Tk - fEk + MOD) % MOD;
        long long inv_denom = power(denom, MOD - 2, MOD);
        long long ans = (fEk + num % MOD * inv_denom % MOD) % MOD;
        cout << ans << endl;
    }
    return 0;
}
import java.util.Scanner;

public class Main {
    static final long MOD = 1_000_000_007;

    static long power(long base, long exp, long mod) {
        long result = 1;
        base %= mod;
        while (exp > 0) {
            if ((exp & 1) == 1) result = result * base % mod;
            base = base * base % mod;
            exp >>= 1;
        }
        return result;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long n = sc.nextLong();
        long k = sc.nextLong();

        long T = power(2, n, MOD);

        if (n % 2 == 1) {
            long half = power(2, n - 1, MOD);
            long fE = (1 + half) % MOD;
            System.out.println(power(fE, k, MOD));
        } else {
            long half = power(2, n - 1, MOD);
            long fE = (half - 1 + MOD) % MOD;
            long Tk = power(T, k, MOD);
            long fEk = power(fE, k, MOD);
            long denom = (half + 1) % MOD;
            long num = (Tk - fEk + MOD) % MOD;
            long invDenom = power(denom, MOD - 2, MOD);
            long ans = (fEk + num % MOD * invDenom % MOD) % MOD;
            System.out.println(ans);
        }
    }
}