题目链接
题目描述
给定一个长度为 的数组
。我们需要构造一个非负整数
,其二进制位数不超过数组中最大值的二进制位数。然后,可以对数组
重复执行以下操作:
- 选择一个下标
和当前的整数
。
- 将
更新为
(按位或)。
- 将
更新为
(按位与)。
目标是:
- 找到使数组元素最终总和最大的操作方式。
- 在所有能达到该最大总和的初始
中,找到值最小的那个。 输出这个最大总和与最小的初始
。
解题思路
本题的核心在于理解位运算操作的本质。操作 和
有一个重要的不变量:对于任意二进制位,该位上
的总数在整个系统(即数组
和整数
)中是保持不变的。
[ \text{popcount}(A_i') + \text{popcount}(k') = \text{popcount}(A_i \lor k) + \text{popcount}(A_i \land k) = \text{popcount}(A_i) + \text{popcount}(k) ]
这意味着,对于每一个二进制位 (例如,个位、二位、四位……),我们可以独立地分析。
令
为初始数组
中第
位是
的元素个数。令
为我们选择的初始整数
的第
位(
或
)。那么在整个系统中,第
位上
的总个数为
。
无论我们执行多少次操作,这个总数不变。操作的效果,本质上是在数组元素和 之间重新分配这些值为
的位。为了使数组
的总和最大,我们应该把尽可能多的
从
“移动”到数组的元素中。
对于第 位,我们最多可以在
个数组元素中都拥有一个
。因此,经过一系列操作后,数组
中第
位为
的元素个数最多可以达到
。
数组的最终总和
可以表示为:
[ S(k) = \sum_{j} \min(n, C_j + k_j) \cdot 2^j ]
我们的目标是选择一个符合条件的 来最大化
,并在满足最大化的前提下让
最小。
-
如何最大化总和? 为了最大化
,我们需要对每一个二进制位
做出决策:
的第
位(
)应该是
还是
?
- 如果我们选择
,第
位对总和的贡献是
。
- 如果我们选择
,贡献是
。 显然,当且仅当
时,选择
才能增加总和。这种情况只在
时发生。换句话说,只要数组中存在至少一个数在第
位上是
,我们就可以通过在
的第
位放一个
来增加最终的总和。
- 如果我们选择
-
如何使
最小? 要使
最小,我们应该只在“必须”时才将
的某一位设为
。“必须”的条件就是能使总和增大的情况。
- 如果
,我们必须选择
才能达到最大和。
- 如果
(即所有数组元素的第
位都已经是
),选择
并不会增加总和。为了使
最小,此时我们应选择
。
综上,最小的、能使总和最大的
(记为
)的第
位应该这样确定:
- 若初始数组所有元素的第
位都为
,则
。
- 否则,
。
“所有元素的第
位都为
”等价于“所有元素的按位与结果的第
位为
”。令
。那么
就是
的按位取反结果。
- 如果
-
处理
的位数限制 题目要求
的二进制位数不超过数组最大值
的位数。设
的位数为
(特别地,
的位数为
),这意味着
必须小于
。我们可以构造一个掩码
,最终的
就是
。
-
计算最大总和 确定了最优的
之后,我们需要计算最终数组的总和。对于每一个二进制位
,系统中
的总数为
(其中
是原数组该位为
的个数,
是
的第
位)。为了最大化总和,这些
会被优先分配给数组中的
个位置。因此,最终数组在第
位上会有
个
。最大总和就是将所有位的贡献加起来: [ S_{\max} = \sum_{j} \min(n, C_j + k_j) \cdot 2^j ]
代码
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <cmath>
using namespace std;
void solve() {
int n;
cin >> n;
vector<long long> a(n);
long long and_all = -1; // 记录所有元素的按位与结果, -1的二进制全是1, 适合做初始值
long long max_val = 0; // 记录数组中的最大值
for (int i = 0; i < n; ++i) {
cin >> a[i];
if (and_all == -1) { // 处理第一个元素
and_all = a[i];
} else {
and_all &= a[i];
}
max_val = max(max_val, a[i]);
}
if (n == 1) { // 特殊情况:数组只有一个元素
cout << a[0] << " " << 0 << "\n";
return;
}
long long k = 0;
int bits;
if (max_val == 0) { // 特殊处理最大值为0的情况
bits = 1;
} else {
bits = floor(log2(max_val)) + 1; // 计算最大值的二进制位数
}
long long mask = (1LL << bits) - 1; // 创建一个位数限制的掩码
k = (~and_all) & mask; // 计算最优k值
long long max_sum = 0;
for (int j = 0; j < bits; ++j) { // 遍历每一位计算贡献
long long c_j = 0;
for (long long x : a) {
if ((x >> j) & 1) {
c_j++;
}
}
long long k_j = (k >> j) & 1;
long long c_j_final = min((long long)n, c_j + k_j);
max_sum += c_j_final * (1LL << j);
}
cout << max_sum << " " << k << "\n";
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int t;
cin >> t;
while (t--) {
solve();
}
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int t = sc.nextInt();
while (t-- > 0) {
solve(sc);
}
}
private static void solve(Scanner sc) {
int n = sc.nextInt();
long[] a = new long[n];
long and_all = -1; // 记录所有元素的按位与结果, -1的二进制全是1, 适合做初始值
long max_val = 0; // 记录数组中的最大值
for (int i = 0; i < n; i++) {
a[i] = sc.nextLong();
if (and_all == -1) { // 处理第一个元素
and_all = a[i];
} else {
and_all &= a[i];
}
max_val = Math.max(max_val, a[i]);
}
if (n == 1) { // 特殊情况:数组只有一个元素
System.out.println(a[0] + " " + 0);
return;
}
long k = 0;
int bits;
if (max_val == 0) { // 特殊处理最大值为0的情况
bits = 1;
} else {
// 计算最大值的二进制位数
bits = 64 - Long.numberOfLeadingZeros(max_val);
}
// 创建一个位数限制的掩码
long mask = (1L << bits) - 1;
// 计算最优k值
k = (~and_all) & mask;
long maxSum = 0;
for (int j = 0; j < bits; j++) { // 遍历每一位计算贡献
long c_j = 0;
for (long x : a) {
if (((x >> j) & 1) == 1) {
c_j++;
}
}
long k_j = (k >> j) & 1;
long c_j_final = Math.min(n, c_j + k_j);
maxSum += c_j_final * (1L << j);
}
System.out.println(maxSum + " " + k);
}
}
import math
def solve():
n = int(input())
a = list(map(int, input().split()))
if n == 1: # 特殊情况:数组只有一个元素
print(a[0], 0)
return
and_all = a[0] # 记录所有元素的按位与结果
max_val = 0 # 记录数组中的最大值
for x in a:
and_all &= x
max_val = max(max_val, x)
k = 0
if max_val == 0: # 特殊处理最大值为0的情况
bits = 1
else:
bits = max_val.bit_length() # 计算最大值的二进制位数
mask = (1 << bits) - 1 # 创建一个位数限制的掩码
k = (~and_all) & mask # 计算最优k值
max_sum = 0
for j in range(bits): # 遍历每一位计算贡献
c_j = 0
for x in a:
if (x >> j) & 1:
c_j += 1
k_j = (k >> j) & 1
c_j_final = min(n, c_j + k_j)
max_sum += c_j_final * (1 << j)
print(max_sum, k)
t = int(input())
for _ in range(t):
solve()
算法及复杂度
- 算法:位运算
- 时间复杂度:对于每组测试数据,我们首先需要
遍历数组计算
and_all
和max_val
。然后,我们需要再次遍历数组来计算每一位的贡献,这个过程的复杂度为。因此,总时间复杂度为
。
- 空间复杂度:我们只需要存储输入的数组,空间复杂度为
。