题目链接
题目描述
小红拿到了一个数组,她准备对这个数组进行如下两种操作:
- 删除数组的第一个元素。该操作的花费为
。
- 使得数组中任意一个元素加 1 或者减 1。该操作的花费为
。
小红希望用尽可能少的花费使得数组所有元素都相等,请你帮小红求出最小的花费。
输入:
- 第一行输入三个正整数
、
、
,代表数组的大小以及两种操作的花费
- 第二行输入
个正整数,代表数组的元素
输出:
- 输出一个整数,代表将所有元素变成相等的最小总花费
解题思路
这是一个动态查询问题,可以通过以下步骤解决:
-
关键发现:
- 对于每个子数组,最优的目标值是中位数
- 需要动态维护元素个数和元素和
- 需要能快速查询第k大的元素
-
解题策略:
- 使用树状数组维护元素个数和元素和
- 离散化处理数组元素
- 使用二分查找第k大的元素
-
具体步骤:
- 对数组元素进行离散化处理
- 使用两个树状数组分别维护个数和元素和
- 依次删除前缀,计算每种情况的代价
代码
#include <iostream>
#include <vector>
#include <map>
using namespace std;
using ll=long long;
const int N=100100;
int a[N];
int n;
ll x,y;
struct BIT{
vector<ll> tr1;
int n;
BIT(int n=200100):n(n),tr1(n+5){}
void add(int now,ll val){
for(int i=now;i<=n;i+=i&-i){
tr1[i]+=val;
}
}
ll query(int now){
ll res=0;
for(int i=now;i>0;i-=i&-i)
res+=tr1[i];
return res;
}
ll query(int l,int r){
return query(r)-query(l-1);
}
int kth(int k){
int ans=0,res=0;
for(int i=1<<__lg(n);i>0;i>>=1){
ans+=i;
if(ans<n&&res+tr1[ans]<k)
res+=tr1[ans];
else
ans-=i;
}
return ans+1;
}
};
map<int,int> ls;
int sl[N];
int ls_cnt=0;
void init()
{
for(int i=1;i<=n;++i)
ls[a[i]]=0;
for(auto &p:ls)
{
p.second=++ls_cnt;
sl[ls_cnt]=p.first;
}
}
int main(void)
{
ios::sync_with_stdio(false);
cin.tie(0);
__int128_t ans=9e18,r1,r2,r3;
int i,idx;
cin>>n>>x>>y;
for(i=1;i<=n;++i)
cin>>a[i];
init();
BIT cnt(n+5),num(n+5);
for(i=1;i<=n;++i)
{
cnt.add(ls[a[i]],1);
num.add(ls[a[i]],a[i]);
}
for(i=1;i<=n;++i)
{
idx=cnt.kth((n-i)/2+1);
r1=cnt.query(1,idx)*sl[idx]-num.query(1,idx);
r2=num.query(idx,n)-cnt.query(idx,n)*sl[idx];
r3=i-1;
ans=min(ans,r1*y+r2*y+r3*x);
cnt.add(ls[a[i]],-1);
num.add(ls[a[i]],-a[i]);
}
cout<<(ll)ans;
return 0;
}
import java.util.*;
import java.io.*;
import java.math.BigInteger;
public class Main {
static class BIT {
long[] tree;
int n;
BIT(int n) {
this.n = n;
tree = new long[n + 5];
}
void add(int pos, long val) {
for(int i = pos; i <= n; i += i & -i) {
tree[i] += val;
}
}
long query(int pos) {
long res = 0;
for(int i = pos; i > 0; i -= i & -i) {
res += tree[i];
}
return res;
}
long query(int l, int r) {
return query(r) - query(l - 1);
}
int kth(int k) {
int ans = 0, res = 0;
for(int i = 1 << (31 - Integer.numberOfLeadingZeros(n)); i > 0; i >>= 1) {
ans += i;
if(ans < n && res + tree[ans] < k) {
res += tree[ans];
} else {
ans -= i;
}
}
return ans + 1;
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String[] line = br.readLine().split(" ");
int n = Integer.parseInt(line[0]);
BigInteger x = new BigInteger(line[1]);
BigInteger y = new BigInteger(line[2]);
int[] a = new int[n + 1];
line = br.readLine().split(" ");
TreeMap<Integer, Integer> map = new TreeMap<>();
for(int i = 1; i <= n; i++) {
a[i] = Integer.parseInt(line[i - 1]);
map.put(a[i], 0);
}
int idx = 1;
int[] sl = new int[n + 1];
for(Map.Entry<Integer, Integer> entry : map.entrySet()) {
entry.setValue(idx);
sl[idx] = entry.getKey();
idx++;
}
BIT cnt = new BIT(n + 5);
BIT num = new BIT(n + 5);
for(int i = 1; i <= n; i++) {
cnt.add(map.get(a[i]), 1);
num.add(map.get(a[i]), a[i]);
}
BigInteger ans = null;
for(int i = 1; i <= n; i++) {
int mid = cnt.kth((n - i) / 2 + 1);
BigInteger r1 = BigInteger.valueOf(cnt.query(1, mid))
.multiply(BigInteger.valueOf(sl[mid]))
.subtract(BigInteger.valueOf(num.query(1, mid)));
BigInteger r2 = BigInteger.valueOf(num.query(mid, n))
.subtract(BigInteger.valueOf(cnt.query(mid, n))
.multiply(BigInteger.valueOf(sl[mid])));
BigInteger r3 = BigInteger.valueOf(i - 1);
BigInteger curr = r1.multiply(y).add(r2.multiply(y)).add(r3.multiply(x));
if(ans == null || curr.compareTo(ans) < 0) {
ans = curr;
}
cnt.add(map.get(a[i]), -1);
num.add(map.get(a[i]), -a[i]);
}
System.out.println(ans);
}
}
class BIT:
def __init__(self, n):
self.n = n
self.tree = [0] * (n + 5)
def add(self, pos, val):
while pos <= self.n:
self.tree[pos] += val
pos += pos & -pos
def query(self, pos):
res = 0
while pos > 0:
res += self.tree[pos]
pos -= pos & -pos
return res
def query_range(self, l, r):
return self.query(r) - self.query(l - 1)
def kth(self, k):
ans = res = 0
i = 1 << (self.n.bit_length() - 1)
while i > 0:
ans += i
if ans < self.n and res + self.tree[ans] < k:
res += self.tree[ans]
else:
ans -= i
i >>= 1
return ans + 1
n, x, y = map(int, input().split())
a = [0] + list(map(int, input().split()))
# 离散化
ls = {}
sl = [0] * (n + 1)
for i in range(1, n + 1):
ls[a[i]] = 0
ls_cnt = 0
for k in sorted(ls.keys()):
ls_cnt += 1
ls[k] = ls_cnt
sl[ls_cnt] = k
# 初始化树状数组
cnt = BIT(n + 5)
num = BIT(n + 5)
for i in range(1, n + 1):
cnt.add(ls[a[i]], 1)
num.add(ls[a[i]], a[i])
ans = float('inf')
for i in range(1, n + 1):
idx = cnt.kth((n - i) // 2 + 1)
r1 = cnt.query_range(1, idx) * sl[idx] - num.query_range(1, idx)
r2 = num.query_range(idx, n) - cnt.query_range(idx, n) * sl[idx]
r3 = i - 1
ans = min(ans, r1 * y + r2 * y + r3 * x)
cnt.add(ls[a[i]], -1)
num.add(ls[a[i]], -a[i])
print(ans)
算法及复杂度
- 算法:树状数组 + 离散化
- 时间复杂度:
- 每次操作需要
的时间
- 空间复杂度:
- 需要存储树状数组和离散化映射
注意:
- 需要使用__int128_t避免中间计算溢出
- 使用离散化处理数组元素
- 需要两个树状数组分别维护个数和元素和
- 注意1-based索引的处理