def solve():
    import sys
    import math
    import bisect

    input = sys.stdin.read
    data = input().split()
    n = int(data[0])
    a = list(map(int, data[1:]))

    if n == 1:
        print(0)
        return

    total = sum(a)

    # 情况1:所有数可以相等
    if total % n == 0:
        target = total // n
        ops = sum(abs(x - target) for x in a) // 2
        print(ops)
        return

    # 情况2:n-1个数相等
    a_sorted = sorted(a)
    best_ops = float("inf")

    # 尝试排除最小值或最大值
    candidates = [a_sorted[0], a_sorted[-1]]

    for exclude_val in candidates:
        # 找到要排除的索引
        idx = a_sorted.index(exclude_val)
        # 构建排除后的数组
        b = a_sorted[:idx] + a_sorted[idx + 1 :]
        m = len(b)  # = n-1

        # 计算排除后的总和
        T = total - exclude_val

        # 候选的众数值
        avg = T / m
        x_candidates = [math.floor(avg), math.ceil(avg)]

        for x in x_candidates:
            # 二分查找找到第一个大于等于x的位置
            pos = bisect.bisect_left(b, x)

            # 计算左右两边的差值
            # 左边:所有小于x的数需要增加到x
            left_diff = pos * x - sum(b[:pos])
            # 右边:所有大于等于x的数需要减少到x
            right_diff = sum(b[pos:]) - (m - pos) * x

            # 需要的操作次数
            ops = left_diff + right_diff

            # 计算排除的数应该变成的值
            y = total - (n - 1) * x
            # 需要加上排除数的调整次数
            ops += abs(exclude_val - y)

            best_ops = min(best_ops, ops)

    # 操作次数要除以2,因为每次操作同时调整两个数
    print(best_ops // 2)

solve()

原始代码比这个简洁点,但不够清晰,用ai把每步都分开了(它动了我的命名什么的,能过,但没细看,有什么奇奇怪怪的命名和操作请体谅)

主要思路:

  1. 最优众数值是 floor(avg) 或 ceil(avg),这个不知道具体数学证明,应该是标准差那块的,但应该挺好理解的,计算总体标准差即可
  2. 如果能所有数都相等,直接算就好;如果不能,则只需要一个缓冲,其余n-1个数都能变为一样的
  3. 被排除的数一定是最大值或最小值:因为最值距离其他数最远,将它排除后,剩余的 n-1 个数更容易变得接近

优化: 总共要算4次标准差,反正数组已经排序了,用前缀和 + 二分快速计算 diff_l 和 diff_r操作次数 = max(diff_l, diff_r) ——— 操作分为 内部抵消 和 外部缓冲,内部抵消抵消小的,缓冲抵消剩余差值,所以 总操作次数 = min(diff_l, diff_r) + |Δ| = max(diff_l, diff_r)。而对于于小于众数值的部分可以用前缀和很方便计算(diff_l = pos * x - pre[pos]),大于的部分同理。