def solve():
import sys
import math
import bisect
input = sys.stdin.read
data = input().split()
n = int(data[0])
a = list(map(int, data[1:]))
if n == 1:
print(0)
return
total = sum(a)
# 情况1:所有数可以相等
if total % n == 0:
target = total // n
ops = sum(abs(x - target) for x in a) // 2
print(ops)
return
# 情况2:n-1个数相等
a_sorted = sorted(a)
best_ops = float("inf")
# 尝试排除最小值或最大值
candidates = [a_sorted[0], a_sorted[-1]]
for exclude_val in candidates:
# 找到要排除的索引
idx = a_sorted.index(exclude_val)
# 构建排除后的数组
b = a_sorted[:idx] + a_sorted[idx + 1 :]
m = len(b) # = n-1
# 计算排除后的总和
T = total - exclude_val
# 候选的众数值
avg = T / m
x_candidates = [math.floor(avg), math.ceil(avg)]
for x in x_candidates:
# 二分查找找到第一个大于等于x的位置
pos = bisect.bisect_left(b, x)
# 计算左右两边的差值
# 左边:所有小于x的数需要增加到x
left_diff = pos * x - sum(b[:pos])
# 右边:所有大于等于x的数需要减少到x
right_diff = sum(b[pos:]) - (m - pos) * x
# 需要的操作次数
ops = left_diff + right_diff
# 计算排除的数应该变成的值
y = total - (n - 1) * x
# 需要加上排除数的调整次数
ops += abs(exclude_val - y)
best_ops = min(best_ops, ops)
# 操作次数要除以2,因为每次操作同时调整两个数
print(best_ops // 2)
solve()
原始代码比这个简洁点,但不够清晰,用ai把每步都分开了(它动了我的命名什么的,能过,但没细看,有什么奇奇怪怪的命名和操作请体谅)
主要思路:
- 最优众数值是 floor(avg) 或 ceil(avg),这个不知道具体数学证明,应该是标准差那块的,但应该挺好理解的,计算总体标准差即可
- 如果能所有数都相等,直接算就好;如果不能,则只需要一个缓冲,其余n-1个数都能变为一样的
- 被排除的数一定是最大值或最小值:因为最值距离其他数最远,将它排除后,剩余的 n-1 个数更容易变得接近
优化: 总共要算4次标准差,反正数组已经排序了,用前缀和 + 二分快速计算 diff_l 和 diff_r操作次数 = max(diff_l, diff_r) ——— 操作分为 内部抵消 和 外部缓冲,内部抵消抵消小的,缓冲抵消剩余差值,所以 总操作次数 = min(diff_l, diff_r) + |Δ| = max(diff_l, diff_r)。而对于于小于众数值的部分可以用前缀和很方便计算(diff_l = pos * x - pre[pos]),大于的部分同理。

京公网安备 11010502036488号