题目链接

分布式计算任务调度

题目描述

给定 个任务和 个计算节点。每个任务 有一个计算量 。需要将这 个任务分配给 个节点,并遵循以下约束:

  1. 顺序性: 分配给节点 的任务 ID 必须全部小于分配给节点 的任务 ID。
  2. 连续性: 分配给同一个节点的任务 ID 必须是连续的。

目标是找到一种分配方案,使得负载最高的节点的总负载(该节点上所有任务的计算量之和)最小。输出这个最小化的最大负载值。

解题思路

这个问题的核心是“最小化最大负载”,这是应用二分答案 (Binary Search on the Answer) 算法的典型特征。

问题的本质是:将一个包含 个任务计算量的数组,分割成 个连续的子数组,使得这些子数组的和中的最大值最小。

我们不直接去求解这个“最小的最大负载”,而是去猜测它。假设我们猜测这个值是 max_load_limit,然后设计一个函数 check(max_load_limit) 来验证:是否存在一种分割方案,使得每个节点的负载都不超过 max_load_limit,并且使用的节点数不多于 个。

  1. 二分答案的范围

    • 下界 (left): 答案不可能小于任何单个任务的计算量,所以下界是所有 中的最大值。
    • 上界 (right): 答案最多是所有任务的计算量之和(当 时)。
    • 我们就在 [max(L_i), sum(L_i)] 这个区间内进行二分查找。
  2. check(limit) 函数的实现

    • 这个函数的目标是判断:在单个节点负载不能超过 limit 的前提下,最少需要多少个节点才能分配完所有任务。
    • 我们可以使用贪心策略来实现 check 函数: a. 初始化 nodes_needed = 1current_load = 0。 b. 从第一个任务开始遍历: - 对于当前任务 ,如果 current_load + L_i <= limit,则将其分配给当前节点,并更新 current_load += L_i。 - 否则,当前节点已无法容纳该任务,必须启用一个新节点。我们让 nodes_needed++,并将当前任务分配给这个新节点,current_load 重置为 。 c. 遍历结束后,nodes_needed 就是在 limit 约束下所需的最少节点数。 d. check 函数返回 nodes_needed <= N
  3. 二分查找的逻辑

    • check(mid) 返回 true 时,说明 mid 这个最大负载限制是可行的(甚至可能还有更优的、更小的解)。我们记录下这个可能的结果 ans = mid,然后尝试在更小的范围 [left, mid - 1] 中寻找更优解。
    • check(mid) 返回 false 时,说明 mid 这个最大负载限制太小了,无法在 个节点内完成分配。我们需要放宽限制,因此在更大的范围 [mid + 1, right] 中继续查找。

最终,二分查找结束后记录的 ans 就是我们要求的最小化的最大负载。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

// 检查在最大负载为 limit 的情况下,是否可以在 n 个节点内完成任务
bool check(long long limit, int n, const vector<int>& loads) {
    int nodes_needed = 1;
    long long current_load = 0;
    for (int load : loads) {
        if (load > limit) return false; // 单个任务超过限制,无法分配
        if (current_load + load <= limit) {
            current_load += load;
        } else {
            nodes_needed++;
            current_load = load;
        }
    }
    return nodes_needed <= n;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int m, n;
    cin >> m >> n;
    vector<int> loads(m);
    long long total_load = 0;
    int max_single_load = 0;
    for (int i = 0; i < m; ++i) {
        cin >> loads[i];
        total_load += loads[i];
        if (loads[i] > max_single_load) {
            max_single_load = loads[i];
        }
    }

    long long left = max_single_load;
    long long right = total_load;
    long long ans = right;

    while (left <= right) {
        long long mid = left + (right - left) / 2;
        if (check(mid, n, loads)) {
            ans = mid;
            right = mid - 1;
        } else {
            left = mid + 1;
        }
    }

    cout << ans << endl;

    return 0;
}
import java.util.Scanner;
import java.util.Arrays;

public class Main {

    // 检查在最大负载为 limit 的情况下,是否可以在 n 个节点内完成任务
    private static boolean check(long limit, int n, int[] loads) {
        int nodesNeeded = 1;
        long currentLoad = 0;
        for (int load : loads) {
            if (load > limit) return false;
            if (currentLoad + load <= limit) {
                currentLoad += load;
            } else {
                nodesNeeded++;
                currentLoad = load;
            }
        }
        return nodesNeeded <= n;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int m = sc.nextInt();
        int n = sc.nextInt();
        int[] loads = new int[m];
        
        long totalLoad = 0;
        int maxSingleLoad = 0;
        for (int i = 0; i < m; i++) {
            loads[i] = sc.nextInt();
            totalLoad += loads[i];
            if (loads[i] > maxSingleLoad) {
                maxSingleLoad = loads[i];
            }
        }

        long left = maxSingleLoad;
        long right = totalLoad;
        long ans = right;

        while (left <= right) {
            long mid = left + (right - left) / 2;
            if (check(mid, n, loads)) {
                ans = mid;
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        System.out.println(ans);
    }
}
# 检查在最大负载为 limit 的情况下,是否可以在 n 个节点内完成任务
def check(limit, n, loads):
    nodes_needed = 1
    current_load = 0
    for load in loads:
        if load > limit:
            return False
        if current_load + load <= limit:
            current_load += load
        else:
            nodes_needed += 1
            current_load = load
    return nodes_needed <= n

def main():
    m, n = map(int, input().split())
    loads = list(map(int, input().split()))

    # 设置二分查找的边界
    left = max(loads)
    right = sum(loads)
    ans = right

    while left <= right:
        mid = (left + right) // 2
        if check(mid, n, loads):
            ans = mid
            right = mid - 1
        else:
            left = mid + 1
            
    print(ans)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:二分答案、贪心
  • 时间复杂度:,其中 是任务总数, 是所有任务的总计算量。check 函数的时间复杂度是 ,二分查找的范围是 ,因此二分次数约为
  • 空间复杂度:,用于存储所有任务的计算量。