题目链接

小红的数组操作

题目描述

小红拿到了一个数组,她准备对这个数组进行如下两种操作:

  1. 删除数组的第一个元素。该操作的花费为
  2. 使得数组中任意一个元素加 1 或者减 1。该操作的花费为

小红希望用尽可能少的花费使得数组所有元素都相等,请你帮小红求出最小的花费。

输入:

  • 第一行输入三个正整数 ,代表数组的大小以及两种操作的花费
  • 第二行输入 个正整数,代表数组的元素

输出:

  • 输出一个整数,代表将所有元素变成相等的最小总花费

解题思路

这是一个动态查询问题,可以通过以下步骤解决:

  1. 关键发现:

    • 对于每个子数组,最优的目标值是中位数
    • 需要动态维护元素个数和元素和
    • 需要能快速查询第k大的元素
  2. 解题策略:

    • 使用树状数组维护元素个数和元素和
    • 离散化处理数组元素
    • 使用二分查找第k大的元素
  3. 具体步骤:

    • 对数组元素进行离散化处理
    • 使用两个树状数组分别维护个数和元素和
    • 依次删除前缀,计算每种情况的代价

代码

#include <iostream>
#include <vector>
#include <map>
using namespace std;
using ll=long long;
const int N=100100;
int a[N];
int n;
ll x,y;
struct BIT{
	vector<ll> tr1;
	int n;
	
	BIT(int n=200100):n(n),tr1(n+5){}
	
	void add(int now,ll val){
		for(int i=now;i<=n;i+=i&-i){
			tr1[i]+=val;
		}
	}
	
	ll query(int now){
		ll res=0;
		for(int i=now;i>0;i-=i&-i)
			res+=tr1[i];
		return res;
	}
	
	ll query(int l,int r){
		return query(r)-query(l-1);
	}
	
	int kth(int k){
		int ans=0,res=0;
		for(int i=1<<__lg(n);i>0;i>>=1){
			ans+=i;
			if(ans<n&&res+tr1[ans]<k)
				res+=tr1[ans];
			else
				ans-=i;
		}
		return ans+1;
	}
};

map<int,int> ls;
int sl[N];
int ls_cnt=0;
void init()
{
	for(int i=1;i<=n;++i)
		ls[a[i]]=0;
	for(auto &p:ls)
	{
		p.second=++ls_cnt;
		sl[ls_cnt]=p.first;
	}
}

int main(void)
{
	ios::sync_with_stdio(false);
	cin.tie(0);
	__int128_t ans=9e18,r1,r2,r3;
	int i,idx;
	cin>>n>>x>>y;
	for(i=1;i<=n;++i)
		cin>>a[i];
	init();
	BIT cnt(n+5),num(n+5);
	for(i=1;i<=n;++i)
	{
		cnt.add(ls[a[i]],1);
		num.add(ls[a[i]],a[i]);
	}
	for(i=1;i<=n;++i)
	{
		idx=cnt.kth((n-i)/2+1);
		r1=cnt.query(1,idx)*sl[idx]-num.query(1,idx);
		r2=num.query(idx,n)-cnt.query(idx,n)*sl[idx];
		r3=i-1;
		ans=min(ans,r1*y+r2*y+r3*x);
		cnt.add(ls[a[i]],-1);
		num.add(ls[a[i]],-a[i]);
	}
	cout<<(ll)ans;
	return 0;
}
import java.util.*;
import java.io.*;
import java.math.BigInteger;

public class Main {
    static class BIT {
        long[] tree;
        int n;
        
        BIT(int n) {
            this.n = n;
            tree = new long[n + 5];
        }
        
        void add(int pos, long val) {
            for(int i = pos; i <= n; i += i & -i) {
                tree[i] += val;
            }
        }
        
        long query(int pos) {
            long res = 0;
            for(int i = pos; i > 0; i -= i & -i) {
                res += tree[i];
            }
            return res;
        }
        
        long query(int l, int r) {
            return query(r) - query(l - 1);
        }
        
        int kth(int k) {
            int ans = 0, res = 0;
            for(int i = 1 << (31 - Integer.numberOfLeadingZeros(n)); i > 0; i >>= 1) {
                ans += i;
                if(ans < n && res + tree[ans] < k) {
                    res += tree[ans];
                } else {
                    ans -= i;
                }
            }
            return ans + 1;
        }
    }
    
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] line = br.readLine().split(" ");
        int n = Integer.parseInt(line[0]);
        BigInteger x = new BigInteger(line[1]);
        BigInteger y = new BigInteger(line[2]);
        
        int[] a = new int[n + 1];
        line = br.readLine().split(" ");
        TreeMap<Integer, Integer> map = new TreeMap<>();
        for(int i = 1; i <= n; i++) {
            a[i] = Integer.parseInt(line[i - 1]);
            map.put(a[i], 0);
        }
        
        int idx = 1;
        int[] sl = new int[n + 1];
        for(Map.Entry<Integer, Integer> entry : map.entrySet()) {
            entry.setValue(idx);
            sl[idx] = entry.getKey();
            idx++;
        }
        
        BIT cnt = new BIT(n + 5);
        BIT num = new BIT(n + 5);
        for(int i = 1; i <= n; i++) {
            cnt.add(map.get(a[i]), 1);
            num.add(map.get(a[i]), a[i]);
        }
        
        BigInteger ans = null;
        for(int i = 1; i <= n; i++) {
            int mid = cnt.kth((n - i) / 2 + 1);
            BigInteger r1 = BigInteger.valueOf(cnt.query(1, mid))
                .multiply(BigInteger.valueOf(sl[mid]))
                .subtract(BigInteger.valueOf(num.query(1, mid)));
            BigInteger r2 = BigInteger.valueOf(num.query(mid, n))
                .subtract(BigInteger.valueOf(cnt.query(mid, n))
                .multiply(BigInteger.valueOf(sl[mid])));
            BigInteger r3 = BigInteger.valueOf(i - 1);
            
            BigInteger curr = r1.multiply(y).add(r2.multiply(y)).add(r3.multiply(x));
            if(ans == null || curr.compareTo(ans) < 0) {
                ans = curr;
            }
            
            cnt.add(map.get(a[i]), -1);
            num.add(map.get(a[i]), -a[i]);
        }
        
        System.out.println(ans);
    }
}
class BIT:
    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 5)
    
    def add(self, pos, val):
        while pos <= self.n:
            self.tree[pos] += val
            pos += pos & -pos
    
    def query(self, pos):
        res = 0
        while pos > 0:
            res += self.tree[pos]
            pos -= pos & -pos
        return res
    
    def query_range(self, l, r):
        return self.query(r) - self.query(l - 1)
    
    def kth(self, k):
        ans = res = 0
        i = 1 << (self.n.bit_length() - 1)
        while i > 0:
            ans += i
            if ans < self.n and res + self.tree[ans] < k:
                res += self.tree[ans]
            else:
                ans -= i
            i >>= 1
        return ans + 1

n, x, y = map(int, input().split())
a = [0] + list(map(int, input().split()))

# 离散化
ls = {}
sl = [0] * (n + 1)
for i in range(1, n + 1):
    ls[a[i]] = 0
ls_cnt = 0
for k in sorted(ls.keys()):
    ls_cnt += 1
    ls[k] = ls_cnt
    sl[ls_cnt] = k

# 初始化树状数组
cnt = BIT(n + 5)
num = BIT(n + 5)
for i in range(1, n + 1):
    cnt.add(ls[a[i]], 1)
    num.add(ls[a[i]], a[i])

ans = float('inf')
for i in range(1, n + 1):
    idx = cnt.kth((n - i) // 2 + 1)
    r1 = cnt.query_range(1, idx) * sl[idx] - num.query_range(1, idx)
    r2 = num.query_range(idx, n) - cnt.query_range(idx, n) * sl[idx]
    r3 = i - 1
    ans = min(ans, r1 * y + r2 * y + r3 * x)
    cnt.add(ls[a[i]], -1)
    num.add(ls[a[i]], -a[i])

print(ans)

算法及复杂度

  • 算法:树状数组 + 离散化
  • 时间复杂度: - 每次操作需要 的时间
  • 空间复杂度: - 需要存储树状数组和离散化映射

注意:

  1. 需要使用__int128_t避免中间计算溢出
  2. 使用离散化处理数组元素
  3. 需要两个树状数组分别维护个数和元素和
  4. 注意1-based索引的处理