任务批处理负载均衡
题意
把 个连续任务按顺序划分成
个批次(不能打乱顺序),每个批次有一个总计算成本(该批次内所有任务成本之和)。要求找到一种划分方案,使得
个批次的总计算成本的标准差最小。输出每个批次包含的任务数量。
思路
标准差最小,等价于什么?
因为均值是固定的(总成本除以 ,跟怎么分无关),所以最小化标准差就是最小化方差,也就是最小化:
$$
其中 是第
批的总成本,
是均值。
问题转化成:把数组按顺序切 段,每段和尽量接近均值。这是一个经典的区间 DP。
定义 为前
个任务分成
个批次时,所有批次的
之和的最小值。转移方程:
$$
枚举上一段的结束位置 ,把第
到第
个任务作为第
批。
初始状态 ,答案在
。
为了输出方案,再开一个 记录最优转移来源,最后倒推即可还原每段的长度。
代码
#include <bits/stdc++.h>
using namespace std;
int main(){
int n, k;
scanf("%d%d", &n, &k);
vector<long long> pre(n+1, 0);
for(int i = 1; i <= n; i++){
long long x; scanf("%lld", &x);
pre[i] = pre[i-1] + x;
}
double mean = (double)pre[n] / k;
vector<vector<double>> dp(n+1, vector<double>(k+1, 1e18));
vector<vector<int>> from(n+1, vector<int>(k+1, 0));
dp[0][0] = 0;
for(int j = 1; j <= k; j++){
for(int i = j; i <= n - (k - j); i++){
for(int p = j-1; p < i; p++){
double s = (double)(pre[i] - pre[p]) - mean;
double cost = dp[p][j-1] + s * s;
if(cost < dp[i][j] - 1e-9){
dp[i][j] = cost;
from[i][j] = p;
}
}
}
}
vector<int> splits;
int cur = n, g = k;
while(g > 0){
int p = from[cur][g];
splits.push_back(cur - p);
cur = p; g--;
}
reverse(splits.begin(), splits.end());
for(int i = 0; i < k; i++){
if(i) printf(" ");
printf("%d", splits[i]);
}
printf("\n");
return 0;
}
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt(), k = sc.nextInt();
long[] pre = new long[n + 1];
for (int i = 1; i <= n; i++) pre[i] = pre[i - 1] + sc.nextLong();
double mean = (double) pre[n] / k;
double[][] dp = new double[n + 1][k + 1];
int[][] from = new int[n + 1][k + 1];
for (double[] row : dp) Arrays.fill(row, 1e18);
dp[0][0] = 0;
for (int j = 1; j <= k; j++)
for (int i = j; i <= n - (k - j); i++)
for (int p = j - 1; p < i; p++) {
double s = (double)(pre[i] - pre[p]) - mean;
double cost = dp[p][j - 1] + s * s;
if (cost < dp[i][j] - 1e-9) {
dp[i][j] = cost;
from[i][j] = p;
}
}
int[] splits = new int[k];
int cur = n, g = k;
while (g > 0) { int p = from[cur][g]; splits[--g] = cur - p; cur = p; }
StringBuilder sb = new StringBuilder();
for (int i = 0; i < k; i++) { if (i > 0) sb.append(' '); sb.append(splits[i]); }
System.out.println(sb);
}
}
import sys
input = sys.stdin.readline
def main():
n, k = map(int, input().split())
a = list(map(int, input().split()))
pre = [0] * (n + 1)
for i in range(n):
pre[i + 1] = pre[i] + a[i]
mean = pre[n] / k
INF = float('inf')
dp = [[INF] * (k + 1) for _ in range(n + 1)]
fr = [[0] * (k + 1) for _ in range(n + 1)]
dp[0][0] = 0
for j in range(1, k + 1):
for i in range(j, n - (k - j) + 1):
for p in range(j - 1, i):
s = (pre[i] - pre[p]) - mean
cost = dp[p][j - 1] + s * s
if cost < dp[i][j] - 1e-9:
dp[i][j] = cost
fr[i][j] = p
splits = []
cur, g = n, k
while g > 0:
p = fr[cur][g]
splits.append(cur - p)
cur = p; g -= 1
splits.reverse()
print(' '.join(map(str, splits)))
main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', l => lines.push(l.trim()));
rl.on('close', () => {
const [n, k] = lines[0].split(' ').map(Number);
const a = lines[1].split(' ').map(Number);
const pre = new Array(n + 1).fill(0);
for (let i = 1; i <= n; i++) pre[i] = pre[i - 1] + a[i - 1];
const mean = pre[n] / k;
const dp = Array.from({length: n + 1}, () => new Float64Array(k + 1).fill(1e18));
const from = Array.from({length: n + 1}, () => new Int32Array(k + 1));
dp[0][0] = 0;
for (let j = 1; j <= k; j++)
for (let i = j; i <= n - (k - j); i++)
for (let p = j - 1; p < i; p++) {
const s = (pre[i] - pre[p]) - mean;
const cost = dp[p][j - 1] + s * s;
if (cost < dp[i][j] - 1e-9) {
dp[i][j] = cost;
from[i][j] = p;
}
}
const splits = [];
let cur = n, g = k;
while (g > 0) { const p = from[cur][g]; splits.push(cur - p); cur = p; g--; }
splits.reverse();
console.log(splits.join(' '));
});
复杂度
- 时间复杂度:
,三重循环枚举批次、结束位置、分割点。
- 空间复杂度:
,存储
和
数组。

京公网安备 11010502036488号