题目链接
题目描述
给定一个数组,有两种操作:
- 删除数组的第一个元素,花费为
。
- 使数组中任意一个元素加 1 或减 1,花费为
。
目标是使用尽可能少的花费使得数组中所有剩余元素都相等。求最小总花费。
解题思路
这是一个需要结合枚举和数据结构优化的题目。总花费由“删除花费”和“修改花费”两部分构成。
1. 问题转化
我们可以枚举删除数组头部的元素个数。假设我们删除了前 个元素(
),那么:
- 删除花费为
。
- 剩下的任务是,用最小的修改花费使后缀数组
中的所有元素都相等。
2. 中位数性质
对于一个给定的数组,要将其所有元素修改为同一个目标值 ,使得总修改量(即
)最小,最优的目标值
必然是该数组的中位数。
因此,问题转化为:对于每一个后缀数组,我们需要快速计算其中位数,以及所有元素到该中位数的绝对差之和。
3. 动态维护中位数
如果对每个后缀都重新排序来找中位数,总时间复杂度会达到 ,无法通过。
一个更高效的方法是从后往前遍历原数组,依次处理
。在遍历过程中,我们动态地维护当前已遍历元素集合(即当前后缀数组)的中位数。
我们可以使用一个双堆结构来动态维护中位数:
- 一个大顶堆
small_half
,用于存储集合中较小的一半元素。 - 一个小顶堆
large_half
,用于存储集合中较大的一半元素。
我们始终保持以下两个平衡条件:
small_half
的堆顶元素小于等于large_half
的堆顶元素。small_half
的大小等于或比large_half
的大小多 1。
在这两个条件下,集合的中位数永远是 small_half
的堆顶元素。
4. 算法流程
我们从 遍历到
:
- 在第
步,我们将
加入双堆结构中,并进行调整以维持平衡。
- 同时,我们维护两个堆中元素的总和
sum_small
和sum_large
。 - 此时,双堆结构中包含了后缀数组
的所有元素。
- 计算当前总花费:
- 删除花费:
- 修改花费:设中位数为
(
small_half.top()
)。所有元素到中位数的绝对差之和为(M * small_half.size() - sum_small) + (sum_large - M * large_half.size())
。再乘以单次修改花费。
- 删除花费:
- 将删除花费和修改花费相加,更新全局最小总花费。
5. 注意事项
花费的计算结果可能非常大,会超出标准 64 位整型的范围。在 C++ 中需要使用 __int128
来存储总花费,以防止溢出。Python 的整数类型则无此顾虑。
代码
#include <bits/stdc++.h>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
long long x, y;
cin >> n >> x >> y;
vector<long long> a(n);
for (int i = 0; i < n; ++i) {
cin >> a[i];
}
priority_queue<long long> small_half;
priority_queue<long long, vector<long long>, greater<long long>> large_half;
__int128 sum_small = 0, sum_large = 0;
__int128 min_total_cost = -1;
for (int i = n - 1; i >= 0; --i) {
// Add a[i] to the heaps
if (small_half.empty() || a[i] <= small_half.top()) {
small_half.push(a[i]);
sum_small += a[i];
} else {
large_half.push(a[i]);
sum_large += a[i];
}
// Rebalance the heaps
if (small_half.size() > large_half.size() + 1) {
large_half.push(small_half.top());
sum_large += small_half.top();
sum_small -= small_half.top();
small_half.pop();
} else if (large_half.size() > small_half.size()) {
small_half.push(large_half.top());
sum_small += large_half.top();
sum_large -= large_half.top();
large_half.pop();
}
long long median = small_half.top();
__int128 modify_dist = (median * (__int128)small_half.size() - sum_small) +
(sum_large - median * (__int128)large_half.size());
__int128 current_cost = (__int128)i * x + modify_dist * y;
if (min_total_cost == -1 || current_cost < min_total_cost) {
min_total_cost = current_cost;
}
}
// Since __int128 cannot be printed directly with cout
string result = "";
if (min_total_cost == 0) {
result = "0";
} else {
while (min_total_cost > 0) {
result = to_string((int)(min_total_cost % 10)) + result;
min_total_cost /= 10;
}
}
cout << result << endl;
return 0;
}
import java.util.Collections;
import java.util.PriorityQueue;
import java.util.Scanner;
import java.math.BigInteger;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
long x = sc.nextLong();
long y = sc.nextLong();
long[] a = new long[n];
for (int i = 0; i < n; i++) {
a[i] = sc.nextLong();
}
PriorityQueue<Long> small_half = new PriorityQueue<>(Collections.reverseOrder());
PriorityQueue<Long> large_half = new PriorityQueue<>();
BigInteger sum_small = BigInteger.ZERO;
BigInteger sum_large = BigInteger.ZERO;
BigInteger min_total_cost = null;
BigInteger bigX = BigInteger.valueOf(x);
BigInteger bigY = BigInteger.valueOf(y);
for (int i = n - 1; i >= 0; i--) {
if (small_half.isEmpty() || a[i] <= small_half.peek()) {
small_half.add(a[i]);
sum_small = sum_small.add(BigInteger.valueOf(a[i]));
} else {
large_half.add(a[i]);
sum_large = sum_large.add(BigInteger.valueOf(a[i]));
}
if (small_half.size() > large_half.size() + 1) {
long val = small_half.poll();
large_half.add(val);
sum_small = sum_small.subtract(BigInteger.valueOf(val));
sum_large = sum_large.add(BigInteger.valueOf(val));
} else if (large_half.size() > small_half.size()) {
long val = large_half.poll();
small_half.add(val);
sum_large = sum_large.subtract(BigInteger.valueOf(val));
sum_small = sum_small.add(BigInteger.valueOf(val));
}
long median = small_half.peek();
BigInteger bigMedian = BigInteger.valueOf(median);
BigInteger modify_dist = bigMedian.multiply(BigInteger.valueOf(small_half.size()))
.subtract(sum_small)
.add(sum_large.subtract(bigMedian.multiply(BigInteger.valueOf(large_half.size()))));
BigInteger current_cost = bigX.multiply(BigInteger.valueOf(i))
.add(modify_dist.multiply(bigY));
if (min_total_cost == null || current_cost.compareTo(min_total_cost) < 0) {
min_total_cost = current_cost;
}
}
System.out.println(min_total_cost.toString());
}
}
import heapq
import sys
def solve():
n, x, y = map(int, sys.stdin.readline().split())
a = list(map(int, sys.stdin.readline().split()))
small_half = [] # Max-heap, so we store negative values
large_half = [] # Min-heap
sum_small = 0
sum_large = 0
min_total_cost = -1
for i in range(n - 1, -1, -1):
# Add a[i]
if not small_half or a[i] <= -small_half[0]:
heapq.heappush(small_half, -a[i])
sum_small += a[i]
else:
heapq.heappush(large_half, a[i])
sum_large += a[i]
# Rebalance
if len(small_half) > len(large_half) + 1:
val = -heapq.heappop(small_half)
heapq.heappush(large_half, val)
sum_small -= val
sum_large += val
elif len(large_half) > len(small_half):
val = heapq.heappop(large_half)
heapq.heappush(small_half, -val)
sum_large -= val
sum_small += val
median = -small_half[0]
modify_dist = (median * len(small_half) - sum_small) + (sum_large - median * len(large_half))
current_cost = i * x + modify_dist * y
if min_total_cost == -1 or current_cost < min_total_cost:
min_total_cost = current_cost
print(min_total_cost)
solve()
算法及复杂度
- 算法:枚举 + 动态中位数 (双堆)
- 时间复杂度:
,其中
是数组的大小。我们遍历数组一次,每次向堆中插入元素的操作需要
的时间。
- 空间复杂度:
,用于存储两个堆中的元素。