题目描述
有n堆石子,第i堆石子的石子数量是ai,作为牛客网的一头领头牛,牛牛决定把这些石子搬回牛客。如果牛牛一次搬运的石子数量是k,那么这堆石子将对牛牛产生k^2的负担值。牛牛最多只能搬运m次,每次搬运可以从一堆石子中选出一些石子搬回牛客,每次搬运不能同时从两堆石子中选取石子,每次只能搬运整数个石子。牛牛是一只聪明的牛,他想出了一种搬运计划可以最小化他搬运完这些石子的负担值的总和,但是突然牛牛的死敌牛能出现了,牛能每次可以施展以下的魔法:
x v将第x堆石子的数量变为v
这打乱了牛牛的计划,每次牛能施展一次魔法,牛牛就得重新规划他的搬运方案,但是牛能施展魔法的次数太多了,牛牛根本忙活不过来了,于是他请来了聪明的你帮他写一个程序计算。
样例
输入 3 4 2 2 2 3 1 2 1 3 2 6 输出 10 13 31
第一次修改: 2 2 2 最优的是:首先随便挑一堆石子拿走一个石子代价为1 接着每堆石子一次那完,代价2 * 2 + 1 * 1 + 2 * 2 总代价为10 第二次修改: 3 2 2 最优的是:首先从第一堆石子拿走一个石子代价为1 接着每堆石子一次那完,代价2 * 2 + 2 * 2 + 2 * 2 总代价为13 第二次修改: 3 6 2 最优的是:首先从第二堆石子拿走3个石子代价为9 接着每堆石子一次那完,代价3 * 3 + 3 * 3 + 2 * 2 总代价为31
算法
(结论推导 + 线段树优化带修改的dp )
一堆个数为a的石子分m次搬运,每次搬运k个石子代价为k * k,搬完所有石子的最小代价是多少。结论很容易猜到就是,将a个石子平均成m份进行搬运,一下是证明:
有了这个结论之后,很容易能想到dp的做法,先将每一对堆石子用j次搬运的最小代价求出来,接着用一个O(n^3)的转移求出答案,但是本题的dp是带修改的,就需要优化状态转移(后面的我想到了要用线段树优化但是具体如何操作实在是想不出来,参考了其他人的题解),我们发现每次只会修改一个位置,我们用线段树维护一段区间[x1,x2,x3,x4]的用[1 ~ m]次将区间所有数搬运完的最小代价,用O(m^2)的时间初始化单个节点的区间[x],然后用[m^2logn]的时间修改整棵线段树(因为数据范围很小所以可以,太妙了!)总的时间复杂度就是O(qm^2logn)
时间复杂度O(qm^2logn)
C++ 代码
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <unordered_map> #include <vector> #include <queue> #include <set> #include <bitset> #include <cmath> #define P 131 #define lc u << 1 #define rc u << 1 | 1 using namespace std; typedef long long LL; const int N = 410; const LL INF = 1e15; struct Node { int l,r; LL f[N];//表示tr[u]维护的区间用N次搬运的最小代价 }tr[N * 4]; int a[N]; int n,m; LL calc(int x,int k) { if(k == 0) return 1e15; if(x < k) return 1e15; if(x % k == 0) return k * 1ll * (x / k) * (x / k); else return (x % k) * 1ll * ((x + k - 1) / k) * ((x + k - 1) / k) + (k - x % k) * 1ll * (x / k) * (x / k); } void pushup(int u) { for(int i = 1;i <= m;i ++) { tr[u].f[i] = INF; for(int j = 1;j < i;j ++) tr[u].f[i] = min(tr[u].f[i],tr[lc].f[i - j] + tr[rc].f[j]); } } void build(int u,int l,int r) { if(l == r) { tr[u] = {l,r}; for(int i = 1;i <= m;i ++) tr[u].f[i] = calc(a[l],i); return; } tr[u] = {l,r}; int mid = l + r >> 1; build(lc,l,mid); build(rc,mid + 1,r); pushup(u); } void modify(int u,int x,int a) { if(tr[u].l == x && tr[u].r == x) { for(int i = 1;i <= m;i ++) tr[u].f[i] = calc(a,i); return; } int mid = (tr[u].l + tr[u].r) >> 1; if(x <= mid) modify(lc,x,a); else modify(rc,x,a); pushup(u); } void solve() { scanf("%d%d",&n,&m); for(int i = 1;i <= n;i ++) scanf("%d",&a[i]); build(1,1,n); int q; scanf("%d",&q); while(q --) { int x,y; scanf("%d%d",&x,&y); modify(1,x,y); printf("%lld\n",tr[1].f[m]); } } int main() { #ifdef LOCAL freopen("in.txt", "r", stdin); freopen("out.txt", "w", stdout); #else #endif // LOCAL int T = 1; // init(500); // scanf("%d",&T); while(T --) { // scanf("%lld%lld",&n,&m); solve(); // test(); } return 0; }
是个好题!!!