题目链接

小红的数组操作

题目描述

给定一个数组,有两种操作:

  1. 删除数组的第一个元素,花费为
  2. 使数组中任意一个元素加 1 或减 1,花费为

目标是使用尽可能少的花费使得数组中所有剩余元素都相等。求最小总花费。

解题思路

这是一个需要结合枚举和数据结构优化的题目。总花费由“删除花费”和“修改花费”两部分构成。

1. 问题转化 我们可以枚举删除数组头部的元素个数。假设我们删除了前 个元素(),那么:

  • 删除花费为
  • 剩下的任务是,用最小的修改花费使后缀数组 中的所有元素都相等。

2. 中位数性质 对于一个给定的数组,要将其所有元素修改为同一个目标值 ,使得总修改量(即 )最小,最优的目标值 必然是该数组的中位数

因此,问题转化为:对于每一个后缀数组,我们需要快速计算其中位数,以及所有元素到该中位数的绝对差之和。

3. 动态维护中位数 如果对每个后缀都重新排序来找中位数,总时间复杂度会达到 ,无法通过。 一个更高效的方法是从后往前遍历原数组,依次处理 。在遍历过程中,我们动态地维护当前已遍历元素集合(即当前后缀数组)的中位数。

我们可以使用一个双堆结构来动态维护中位数:

  • 一个大顶堆 small_half,用于存储集合中较小的一半元素。
  • 一个小顶堆 large_half,用于存储集合中较大的一半元素。

我们始终保持以下两个平衡条件:

  1. small_half 的堆顶元素小于等于 large_half 的堆顶元素。
  2. small_half 的大小等于或比 large_half 的大小多 1。

在这两个条件下,集合的中位数永远是 small_half 的堆顶元素。

4. 算法流程 我们从 遍历到

  1. 在第 步,我们将 加入双堆结构中,并进行调整以维持平衡。
  2. 同时,我们维护两个堆中元素的总和 sum_smallsum_large
  3. 此时,双堆结构中包含了后缀数组 的所有元素。
  4. 计算当前总花费:
    • 删除花费:
    • 修改花费:设中位数为 (small_half.top())。所有元素到中位数的绝对差之和为 (M * small_half.size() - sum_small) + (sum_large - M * large_half.size())。再乘以单次修改花费
  5. 将删除花费和修改花费相加,更新全局最小总花费。

5. 注意事项 花费的计算结果可能非常大,会超出标准 64 位整型的范围。在 C++ 中需要使用 __int128 来存储总花费,以防止溢出。Python 的整数类型则无此顾虑。

代码

#include <bits/stdc++.h>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    long long x, y;
    cin >> n >> x >> y;
    vector<long long> a(n);
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
    }

    priority_queue<long long> small_half;
    priority_queue<long long, vector<long long>, greater<long long>> large_half;
    __int128 sum_small = 0, sum_large = 0;
    __int128 min_total_cost = -1;

    for (int i = n - 1; i >= 0; --i) {
        // Add a[i] to the heaps
        if (small_half.empty() || a[i] <= small_half.top()) {
            small_half.push(a[i]);
            sum_small += a[i];
        } else {
            large_half.push(a[i]);
            sum_large += a[i];
        }

        // Rebalance the heaps
        if (small_half.size() > large_half.size() + 1) {
            large_half.push(small_half.top());
            sum_large += small_half.top();
            sum_small -= small_half.top();
            small_half.pop();
        } else if (large_half.size() > small_half.size()) {
            small_half.push(large_half.top());
            sum_small += large_half.top();
            sum_large -= large_half.top();
            large_half.pop();
        }

        long long median = small_half.top();
        __int128 modify_dist = (median * (__int128)small_half.size() - sum_small) + 
                               (sum_large - median * (__int128)large_half.size());
        __int128 current_cost = (__int128)i * x + modify_dist * y;

        if (min_total_cost == -1 || current_cost < min_total_cost) {
            min_total_cost = current_cost;
        }
    }

    // Since __int128 cannot be printed directly with cout
    string result = "";
    if (min_total_cost == 0) {
        result = "0";
    } else {
        while (min_total_cost > 0) {
            result = to_string((int)(min_total_cost % 10)) + result;
            min_total_cost /= 10;
        }
    }
    cout << result << endl;

    return 0;
}
import java.util.Collections;
import java.util.PriorityQueue;
import java.util.Scanner;
import java.math.BigInteger;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        long x = sc.nextLong();
        long y = sc.nextLong();
        long[] a = new long[n];
        for (int i = 0; i < n; i++) {
            a[i] = sc.nextLong();
        }

        PriorityQueue<Long> small_half = new PriorityQueue<>(Collections.reverseOrder());
        PriorityQueue<Long> large_half = new PriorityQueue<>();
        BigInteger sum_small = BigInteger.ZERO;
        BigInteger sum_large = BigInteger.ZERO;
        BigInteger min_total_cost = null;

        BigInteger bigX = BigInteger.valueOf(x);
        BigInteger bigY = BigInteger.valueOf(y);

        for (int i = n - 1; i >= 0; i--) {
            if (small_half.isEmpty() || a[i] <= small_half.peek()) {
                small_half.add(a[i]);
                sum_small = sum_small.add(BigInteger.valueOf(a[i]));
            } else {
                large_half.add(a[i]);
                sum_large = sum_large.add(BigInteger.valueOf(a[i]));
            }

            if (small_half.size() > large_half.size() + 1) {
                long val = small_half.poll();
                large_half.add(val);
                sum_small = sum_small.subtract(BigInteger.valueOf(val));
                sum_large = sum_large.add(BigInteger.valueOf(val));
            } else if (large_half.size() > small_half.size()) {
                long val = large_half.poll();
                small_half.add(val);
                sum_large = sum_large.subtract(BigInteger.valueOf(val));
                sum_small = sum_small.add(BigInteger.valueOf(val));
            }

            long median = small_half.peek();
            BigInteger bigMedian = BigInteger.valueOf(median);

            BigInteger modify_dist = bigMedian.multiply(BigInteger.valueOf(small_half.size()))
                                             .subtract(sum_small)
                                             .add(sum_large.subtract(bigMedian.multiply(BigInteger.valueOf(large_half.size()))));

            BigInteger current_cost = bigX.multiply(BigInteger.valueOf(i))
                                        .add(modify_dist.multiply(bigY));

            if (min_total_cost == null || current_cost.compareTo(min_total_cost) < 0) {
                min_total_cost = current_cost;
            }
        }
        System.out.println(min_total_cost.toString());
    }
}
import heapq
import sys

def solve():
    n, x, y = map(int, sys.stdin.readline().split())
    a = list(map(int, sys.stdin.readline().split()))

    small_half = []  # Max-heap, so we store negative values
    large_half = []  # Min-heap
    sum_small = 0
    sum_large = 0
    min_total_cost = -1

    for i in range(n - 1, -1, -1):
        # Add a[i]
        if not small_half or a[i] <= -small_half[0]:
            heapq.heappush(small_half, -a[i])
            sum_small += a[i]
        else:
            heapq.heappush(large_half, a[i])
            sum_large += a[i]
        
        # Rebalance
        if len(small_half) > len(large_half) + 1:
            val = -heapq.heappop(small_half)
            heapq.heappush(large_half, val)
            sum_small -= val
            sum_large += val
        elif len(large_half) > len(small_half):
            val = heapq.heappop(large_half)
            heapq.heappush(small_half, -val)
            sum_large -= val
            sum_small += val

        median = -small_half[0]
        
        modify_dist = (median * len(small_half) - sum_small) + (sum_large - median * len(large_half))
        current_cost = i * x + modify_dist * y

        if min_total_cost == -1 or current_cost < min_total_cost:
            min_total_cost = current_cost
    
    print(min_total_cost)

solve()

算法及复杂度

  • 算法:枚举 + 动态中位数 (双堆)
  • 时间复杂度:,其中 是数组的大小。我们遍历数组一次,每次向堆中插入元素的操作需要 的时间。
  • 空间复杂度:,用于存储两个堆中的元素。