任务批处理负载均衡

题意

个连续任务按顺序划分成 个批次(不能打乱顺序),每个批次有一个总计算成本(该批次内所有任务成本之和)。要求找到一种划分方案,使得 个批次的总计算成本的标准差最小。输出每个批次包含的任务数量。

思路

标准差最小,等价于什么?

因为均值是固定的(总成本除以 ,跟怎么分无关),所以最小化标准差就是最小化方差,也就是最小化:

$$

其中 是第 批的总成本, 是均值。

问题转化成:把数组按顺序切 段,每段和尽量接近均值。这是一个经典的区间 DP

定义 为前 个任务分成 个批次时,所有批次的 之和的最小值。转移方程:

$$

枚举上一段的结束位置 ,把第 到第 个任务作为第 批。

初始状态 ,答案在

为了输出方案,再开一个 记录最优转移来源,最后倒推即可还原每段的长度。

代码

#include <bits/stdc++.h>
using namespace std;

int main(){
    int n, k;
    scanf("%d%d", &n, &k);
    vector<long long> pre(n+1, 0);
    for(int i = 1; i <= n; i++){
        long long x; scanf("%lld", &x);
        pre[i] = pre[i-1] + x;
    }
    double mean = (double)pre[n] / k;

    vector<vector<double>> dp(n+1, vector<double>(k+1, 1e18));
    vector<vector<int>> from(n+1, vector<int>(k+1, 0));
    dp[0][0] = 0;

    for(int j = 1; j <= k; j++){
        for(int i = j; i <= n - (k - j); i++){
            for(int p = j-1; p < i; p++){
                double s = (double)(pre[i] - pre[p]) - mean;
                double cost = dp[p][j-1] + s * s;
                if(cost < dp[i][j] - 1e-9){
                    dp[i][j] = cost;
                    from[i][j] = p;
                }
            }
        }
    }

    vector<int> splits;
    int cur = n, g = k;
    while(g > 0){
        int p = from[cur][g];
        splits.push_back(cur - p);
        cur = p; g--;
    }
    reverse(splits.begin(), splits.end());
    for(int i = 0; i < k; i++){
        if(i) printf(" ");
        printf("%d", splits[i]);
    }
    printf("\n");
    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(), k = sc.nextInt();
        long[] pre = new long[n + 1];
        for (int i = 1; i <= n; i++) pre[i] = pre[i - 1] + sc.nextLong();
        double mean = (double) pre[n] / k;

        double[][] dp = new double[n + 1][k + 1];
        int[][] from = new int[n + 1][k + 1];
        for (double[] row : dp) Arrays.fill(row, 1e18);
        dp[0][0] = 0;

        for (int j = 1; j <= k; j++)
            for (int i = j; i <= n - (k - j); i++)
                for (int p = j - 1; p < i; p++) {
                    double s = (double)(pre[i] - pre[p]) - mean;
                    double cost = dp[p][j - 1] + s * s;
                    if (cost < dp[i][j] - 1e-9) {
                        dp[i][j] = cost;
                        from[i][j] = p;
                    }
                }

        int[] splits = new int[k];
        int cur = n, g = k;
        while (g > 0) { int p = from[cur][g]; splits[--g] = cur - p; cur = p; }

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < k; i++) { if (i > 0) sb.append(' '); sb.append(splits[i]); }
        System.out.println(sb);
    }
}
import sys
input = sys.stdin.readline

def main():
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    pre = [0] * (n + 1)
    for i in range(n):
        pre[i + 1] = pre[i] + a[i]
    mean = pre[n] / k

    INF = float('inf')
    dp = [[INF] * (k + 1) for _ in range(n + 1)]
    fr = [[0] * (k + 1) for _ in range(n + 1)]
    dp[0][0] = 0

    for j in range(1, k + 1):
        for i in range(j, n - (k - j) + 1):
            for p in range(j - 1, i):
                s = (pre[i] - pre[p]) - mean
                cost = dp[p][j - 1] + s * s
                if cost < dp[i][j] - 1e-9:
                    dp[i][j] = cost
                    fr[i][j] = p

    splits = []
    cur, g = n, k
    while g > 0:
        p = fr[cur][g]
        splits.append(cur - p)
        cur = p; g -= 1
    splits.reverse()
    print(' '.join(map(str, splits)))

main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', l => lines.push(l.trim()));
rl.on('close', () => {
    const [n, k] = lines[0].split(' ').map(Number);
    const a = lines[1].split(' ').map(Number);
    const pre = new Array(n + 1).fill(0);
    for (let i = 1; i <= n; i++) pre[i] = pre[i - 1] + a[i - 1];
    const mean = pre[n] / k;

    const dp = Array.from({length: n + 1}, () => new Float64Array(k + 1).fill(1e18));
    const from = Array.from({length: n + 1}, () => new Int32Array(k + 1));
    dp[0][0] = 0;

    for (let j = 1; j <= k; j++)
        for (let i = j; i <= n - (k - j); i++)
            for (let p = j - 1; p < i; p++) {
                const s = (pre[i] - pre[p]) - mean;
                const cost = dp[p][j - 1] + s * s;
                if (cost < dp[i][j] - 1e-9) {
                    dp[i][j] = cost;
                    from[i][j] = p;
                }
            }

    const splits = [];
    let cur = n, g = k;
    while (g > 0) { const p = from[cur][g]; splits.push(cur - p); cur = p; g--; }
    splits.reverse();
    console.log(splits.join(' '));
});

复杂度

  • 时间复杂度:,三重循环枚举批次、结束位置、分割点。
  • 空间复杂度:,存储 数组。