分布式计算任务调度
题意
有 个计算任务(编号 1 到
),每个任务有一个计算量
。需要把这些任务分配给
个处理节点,满足:
- 分配给同一节点的任务 ID 必须连续
- 不同节点之间任务 ID 必须递增(即把任务数组切成连续的若干段)
- 单个任务不可拆分
目标:使负载最高节点与最低节点的差异最小,输出最优方案下负载最高节点的负载量。
思路
把 个任务按顺序切成
段连续子数组,要求最大段和尽量小——这不就是经典的"分割数组的最大值"吗?
为什么最小化最大值就等价于最小化最大最小之差?直觉上想:总量固定,把最大的压下来,各段就更均匀,差值自然更小。严格来说,当最大值取到最小时,最小值不会变得更差(分配更均匀了),所以差值也达到最小。
怎么求"最小的最大段和"?二分答案。
二分的思路是这样的:假设我们限定每个节点最多承担 的负载,能不能把所有任务分成不超过
段?
- 如果能,说明
够大,试试更小的值,
hi = mid - 如果不能,说明
太小了,
lo = mid + 1
验证函数怎么写?贪心地从左到右扫,往当前段里塞任务,塞不下了就开新段。最后看段数是否 。
二分的下界是 (至少要装得下最大的那个任务),上界是
(全塞到一个节点)。
特殊情况: 时,每个任务独占一个节点,答案就是
。
时间复杂度 ,其中
。
代码
#include <bits/stdc++.h>
using namespace std;
int main(){
int m, n;
scanf("%d%d", &m, &n);
vector<long long> a(m);
long long lo = 0, hi = 0;
for(int i = 0; i < m; i++){
scanf("%lld", &a[i]);
if(a[i] > lo) lo = a[i];
hi += a[i];
}
if(n >= m){
printf("%lld\n", lo);
return 0;
}
auto check = [&](long long mid) -> bool {
int cnt = 1;
long long cur = 0;
for(int i = 0; i < m; i++){
if(cur + a[i] > mid){
cnt++;
cur = a[i];
if(cnt > n) return false;
} else {
cur += a[i];
}
}
return true;
};
while(lo < hi){
long long mid = (lo + hi) / 2;
if(check(mid)) hi = mid;
else lo = mid + 1;
}
printf("%lld\n", lo);
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int m = sc.nextInt();
int n = sc.nextInt();
long[] a = new long[m];
long lo = 0, hi = 0;
for (int i = 0; i < m; i++) {
a[i] = sc.nextLong();
if (a[i] > lo) lo = a[i];
hi += a[i];
}
if (n >= m) {
System.out.println(lo);
return;
}
while (lo < hi) {
long mid = (lo + hi) / 2;
if (check(a, n, mid)) hi = mid;
else lo = mid + 1;
}
System.out.println(lo);
}
static boolean check(long[] a, int n, long mid) {
int cnt = 1;
long cur = 0;
for (int i = 0; i < a.length; i++) {
if (cur + a[i] > mid) {
cnt++;
cur = a[i];
if (cnt > n) return false;
} else {
cur += a[i];
}
}
return true;
}
}
import sys
input = sys.stdin.readline
def main():
m, n = map(int, input().split())
a = list(map(int, input().split()))
if n >= m:
print(max(a))
return
lo, hi = max(a), sum(a)
def check(mid):
cnt = 1
cur = 0
for x in a:
if cur + x > mid:
cnt += 1
cur = x
if cnt > n:
return False
else:
cur += x
return True
while lo < hi:
mid = (lo + hi) // 2
if check(mid):
hi = mid
else:
lo = mid + 1
print(lo)
main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', line => lines.push(line.trim()));
rl.on('close', () => {
const [m, n] = lines[0].split(' ').map(Number);
const a = lines[1].split(' ').map(Number);
if (n >= m) {
console.log(Math.max(...a));
return;
}
let lo = Math.max(...a);
let hi = a.reduce((s, x) => s + x, 0);
function check(mid) {
let cnt = 1, cur = 0;
for (let i = 0; i < m; i++) {
if (cur + a[i] > mid) {
cnt++;
cur = a[i];
if (cnt > n) return false;
} else {
cur += a[i];
}
}
return true;
}
while (lo < hi) {
const mid = Math.floor((lo + hi) / 2);
if (check(mid)) hi = mid;
else lo = mid + 1;
}
console.log(lo);
});

京公网安备 11010502036488号