题目链接

组数制进二

题目描述

给定一个长度为 的数组 。我们需要构造一个非负整数 ,其二进制位数不超过数组中最大值的二进制位数。然后,可以对数组 重复执行以下操作:

  • 选择一个下标 和当前的整数
  • 更新为 (按位或)。
  • 更新为 (按位与)。

目标是:

  1. 找到使数组元素最终总和最大的操作方式。
  2. 在所有能达到该最大总和的初始 中,找到值最小的那个。 输出这个最大总和与最小的初始

解题思路

本题的核心在于理解位运算操作的本质。操作 有一个重要的不变量:对于任意二进制位,该位上 的总数在整个系统(即数组 和整数 )中是保持不变的。 [ \text{popcount}(A_i') + \text{popcount}(k') = \text{popcount}(A_i \lor k) + \text{popcount}(A_i \land k) = \text{popcount}(A_i) + \text{popcount}(k) ]

这意味着,对于每一个二进制位 (例如,个位、二位、四位……),我们可以独立地分析。 令 为初始数组 中第 位是 的元素个数。令 为我们选择的初始整数 的第 位()。那么在整个系统中,第 位上 的总个数为

无论我们执行多少次操作,这个总数不变。操作的效果,本质上是在数组元素和 之间重新分配这些值为 的位。为了使数组 的总和最大,我们应该把尽可能多的 “移动”到数组的元素中。

对于第 位,我们最多可以在 个数组元素中都拥有一个 。因此,经过一系列操作后,数组 中第 位为 的元素个数最多可以达到 。 数组的最终总和 可以表示为: [ S(k) = \sum_{j} \min(n, C_j + k_j) \cdot 2^j ]

我们的目标是选择一个符合条件的 来最大化 ,并在满足最大化的前提下让 最小。

  1. 如何最大化总和? 为了最大化 ,我们需要对每一个二进制位 做出决策: 的第 位()应该是 还是

    • 如果我们选择 ,第 位对总和的贡献是
    • 如果我们选择 ,贡献是 。 显然,当且仅当 时,选择 才能增加总和。这种情况只在 时发生。换句话说,只要数组中存在至少一个数在第 位上是 ,我们就可以通过在 的第 位放一个 来增加最终的总和。
  2. 如何使 最小? 要使 最小,我们应该只在“必须”时才将 的某一位设为 。“必须”的条件就是能使总和增大的情况。

    • 如果 ,我们必须选择 才能达到最大和。
    • 如果 (即所有数组元素的第 位都已经是 ),选择 并不会增加总和。为了使 最小,此时我们应选择

    综上,最小的、能使总和最大的 (记为 )的第 位应该这样确定:

    • 若初始数组所有元素的第 位都为 ,则
    • 否则,

    “所有元素的第 位都为 ”等价于“所有元素的按位与结果的第 位为 ”。令 。那么 就是 的按位取反结果。

  3. 处理 的位数限制 题目要求 的二进制位数不超过数组最大值 的位数。设 的位数为 (特别地, 的位数为 ),这意味着 必须小于 。我们可以构造一个掩码 ,最终的 就是

  4. 计算最大总和 确定了最优的 之后,我们需要计算最终数组的总和。对于每一个二进制位 ,系统中 的总数为 (其中 是原数组该位为 的个数, 的第 位)。为了最大化总和,这些 会被优先分配给数组中的 个位置。因此,最终数组在第 位上会有 。最大总和就是将所有位的贡献加起来: [ S_{\max} = \sum_{j} \min(n, C_j + k_j) \cdot 2^j ]

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <cmath>

using namespace std;

void solve() {
    int n;
    cin >> n;
    vector<long long> a(n);
    long long and_all = -1; // 记录所有元素的按位与结果, -1的二进制全是1, 适合做初始值
    long long max_val = 0;  // 记录数组中的最大值

    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        if (and_all == -1) { // 处理第一个元素
            and_all = a[i];
        } else {
            and_all &= a[i];
        }
        max_val = max(max_val, a[i]);
    }

    if (n == 1) { // 特殊情况:数组只有一个元素
        cout << a[0] << " " << 0 << "\n";
        return;
    }

    long long k = 0;
    int bits;
    if (max_val == 0) { // 特殊处理最大值为0的情况
        bits = 1;
    } else {
        bits = floor(log2(max_val)) + 1; // 计算最大值的二进制位数
    }
    long long mask = (1LL << bits) - 1;  // 创建一个位数限制的掩码
    k = (~and_all) & mask;               // 计算最优k值

    long long max_sum = 0;
    for (int j = 0; j < bits; ++j) { // 遍历每一位计算贡献
        long long c_j = 0;
        for (long long x : a) {
            if ((x >> j) & 1) {
                c_j++;
            }
        }
        long long k_j = (k >> j) & 1;
        long long c_j_final = min((long long)n, c_j + k_j);
        max_sum += c_j_final * (1LL << j);
    }
    
    cout << max_sum << " " << k << "\n";
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    int t;
    cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int t = sc.nextInt();
        while (t-- > 0) {
            solve(sc);
        }
    }

    private static void solve(Scanner sc) {
        int n = sc.nextInt();
        long[] a = new long[n];
        long and_all = -1; // 记录所有元素的按位与结果, -1的二进制全是1, 适合做初始值
        long max_val = 0;  // 记录数组中的最大值

        for (int i = 0; i < n; i++) {
            a[i] = sc.nextLong();
            if (and_all == -1) { // 处理第一个元素
                and_all = a[i];
            } else {
                and_all &= a[i];
            }
            max_val = Math.max(max_val, a[i]);
        }

        if (n == 1) { // 特殊情况:数组只有一个元素
            System.out.println(a[0] + " " + 0);
            return;
        }

        long k = 0;
        int bits;
        if (max_val == 0) { // 特殊处理最大值为0的情况
            bits = 1;
        } else {
            // 计算最大值的二进制位数
            bits = 64 - Long.numberOfLeadingZeros(max_val);
        }
        // 创建一个位数限制的掩码
        long mask = (1L << bits) - 1;
        // 计算最优k值
        k = (~and_all) & mask;
        
        long maxSum = 0;
        for (int j = 0; j < bits; j++) { // 遍历每一位计算贡献
            long c_j = 0;
            for (long x : a) {
                if (((x >> j) & 1) == 1) {
                    c_j++;
                }
            }
            long k_j = (k >> j) & 1;
            long c_j_final = Math.min(n, c_j + k_j);
            maxSum += c_j_final * (1L << j);
        }
        
        System.out.println(maxSum + " " + k);
    }
}
import math

def solve():
    n = int(input())
    a = list(map(int, input().split()))

    if n == 1:  # 特殊情况:数组只有一个元素
        print(a[0], 0)
        return

    and_all = a[0]  # 记录所有元素的按位与结果
    max_val = 0     # 记录数组中的最大值
    for x in a:
        and_all &= x
        max_val = max(max_val, x)

    k = 0
    if max_val == 0: # 特殊处理最大值为0的情况
        bits = 1
    else:
        bits = max_val.bit_length() # 计算最大值的二进制位数
    
    mask = (1 << bits) - 1      # 创建一个位数限制的掩码
    k = (~and_all) & mask       # 计算最优k值

    max_sum = 0
    for j in range(bits): # 遍历每一位计算贡献
        c_j = 0
        for x in a:
            if (x >> j) & 1:
                c_j += 1
        k_j = (k >> j) & 1
        c_j_final = min(n, c_j + k_j)
        max_sum += c_j_final * (1 << j)
    
    print(max_sum, k)

t = int(input())
for _ in range(t):
    solve()

算法及复杂度

  • 算法:位运算
  • 时间复杂度:对于每组测试数据,我们首先需要 遍历数组计算 and_allmax_val。然后,我们需要再次遍历数组来计算每一位的贡献,这个过程的复杂度为 。因此,总时间复杂度为
  • 空间复杂度:我们只需要存储输入的数组,空间复杂度为