小红的数组操作

[题目链接](https://www.nowcoder.com/practice/bedca3d0e77644478f1ceb93a9f7889e)

思路

给定长度为 的数组 ,两种操作:

  1. 删除数组的第一个元素,花费
  2. 将任意一个元素加 1 或减 1,花费

求使所有剩余元素相等的最小总花费。

枚举前缀删除长度

我们可以删除前 个元素(),剩余子数组为 。删除花费为 ,还需要用操作 2 将剩余元素全部变成同一个值,花费为 ,其中 为目标值。

众所周知,使绝对值偏差之和最小的目标值 就是中位数。因此对每个后缀 ,操作 2 的最小花费为 乘以该后缀所有元素到中位数的绝对偏差之和。

从右向左插入 + 对顶多重集维护中位数

遍历,每次将 插入数据结构。用对顶多重集(lower half 和 upper half)维护中位数和偏差之和:

  • lo:存储较小的一半元素,hi:存储较大的一半元素。
  • 保持 ,此时中位数为 lo 的最大值。
  • 同时维护 sumLosumHi,即两半的元素之和。

设中位数为 ,则绝对偏差之和为:

$$

每次插入后,总花费为 ,取所有 的最小值即为答案。

溢出处理

注意 最大可达 ,乘积可达 ,超过 64 位整数范围。C++ 中使用 __int128,Java 中使用 BigInteger 处理。

样例演示

输入 ,数组

删除前 剩余数组 中位数 偏差之和 总花费
0 3 12 12
1 2.5→取 2 或 3 4 6
2 3 2 6
3 2 1 7
4 3 0 8

最小花费为

复杂度分析

  • 时间复杂度:,每次插入和平衡操作为
  • 空间复杂度:,多重集存储所有元素。

代码

#include <bits/stdc++.h>
using namespace std;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    long long x, y;
    cin >> n >> x >> y;
    vector<long long> a(n);
    for(int i = 0; i < n; i++) cin >> a[i];

    multiset<long long> lo, hi;
    long long sumLo = 0, sumHi = 0;
    long long ans = (long long)(n - 1) * x;

    auto balance = [&](){
        while((int)lo.size() > (int)hi.size() + 1){
            auto it = prev(lo.end());
            long long v = *it;
            lo.erase(it);
            sumLo -= v;
            hi.insert(v);
            sumHi += v;
        }
        while((int)hi.size() > (int)lo.size()){
            auto it = hi.begin();
            long long v = *it;
            hi.erase(it);
            sumHi -= v;
            lo.insert(v);
            sumLo += v;
        }
    };

    for(int i = n - 1; i >= 0; i--){
        if(lo.empty() || a[i] <= *lo.rbegin()){
            lo.insert(a[i]);
            sumLo += a[i];
        } else {
            hi.insert(a[i]);
            sumHi += a[i];
        }
        balance();

        long long med = *lo.rbegin();
        long long cntL = (long long)lo.size();
        long long cntH = (long long)hi.size();
        long long cost = med * cntL - sumLo + sumHi - med * cntH;
        long long deleteCost = (long long)i * x;

        __int128 total = (__int128)deleteCost + (__int128)y * cost;
        if(total < (__int128)ans){
            ans = (long long)total;
        }
    }

    cout << ans << "\n";
    return 0;
}
import java.util.*;
import java.io.*;
import java.math.BigInteger;

public class Main {
    static TreeMap<Long, Integer> lo = new TreeMap<>();
    static TreeMap<Long, Integer> hi = new TreeMap<>();
    static long sumLo = 0, sumHi = 0;
    static int cntLo = 0, cntHi = 0;

    static void addTo(TreeMap<Long, Integer> map, long v) {
        map.merge(v, 1, Integer::sum);
    }

    static void removeFrom(TreeMap<Long, Integer> map, long v) {
        int c = map.get(v);
        if (c == 1) map.remove(v);
        else map.put(v, c - 1);
    }

    static void insertVal(long v) {
        if (cntLo == 0 || v <= lo.lastKey()) {
            addTo(lo, v);
            sumLo += v;
            cntLo++;
        } else {
            addTo(hi, v);
            sumHi += v;
            cntHi++;
        }
        balance();
    }

    static void balance() {
        while (cntLo > cntHi + 1) {
            long v = lo.lastKey();
            removeFrom(lo, v);
            sumLo -= v;
            cntLo--;
            addTo(hi, v);
            sumHi += v;
            cntHi++;
        }
        while (cntHi > cntLo) {
            long v = hi.firstKey();
            removeFrom(hi, v);
            sumHi -= v;
            cntHi--;
            addTo(lo, v);
            sumLo += v;
            cntLo++;
        }
    }

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        long x = Long.parseLong(st.nextToken());
        long y = Long.parseLong(st.nextToken());
        st = new StringTokenizer(br.readLine());
        long[] a = new long[n];
        for (int i = 0; i < n; i++) a[i] = Long.parseLong(st.nextToken());

        BigInteger ans = BigInteger.valueOf(n - 1).multiply(BigInteger.valueOf(x));

        for (int i = n - 1; i >= 0; i--) {
            insertVal(a[i]);

            long med = lo.lastKey();
            long cost = med * (long) cntLo - sumLo + sumHi - med * (long) cntHi;
            BigInteger total = BigInteger.valueOf(i).multiply(BigInteger.valueOf(x))
                    .add(BigInteger.valueOf(y).multiply(BigInteger.valueOf(cost)));
            if (total.compareTo(ans) < 0) {
                ans = total;
            }
        }

        System.out.println(ans);
    }
}