题目链接

MOE Top-k 路由

题目描述

本题模拟一个稀疏 MOE (Mixture of Experts) 模型中的路由选择过程。给定 个专家和 张 NPU 卡,需要按照以下三步规则,从所有专家中选出最终的 个作为路由目标:

  1. 分组与代表选举

    • 首先,将 个专家平均分配到 张卡上,形成 个组。
    • 在每个组内部,选出概率最高的专家作为该组的“代表”。
  2. 顶层路由(选组)

    • 对所有组的“代表”进行排序(按概率降序,概率相同时组号小的优先)。
    • 从排序后的结果中,选取前 个组。
  3. 二层路由(选专家)

    • 将上一步选出的 个组中的所有专家汇集起来。
    • 对这些专家进行排序(按概率降序,概率相同时专家编号小的优先)。
    • 选取前 位专家作为最终结果。

同时,需要处理一些约束和异常情况,如无法平均分组、参数不合理等。

解题思路

这是一个多阶段的排序和筛选问题。解题的关键在于清晰地实现每一步的逻辑,并正确处理排序的复合条件。

  1. 输入与异常检查

    • 首先读取四个整数 个专家的概率。
    • 立即进行异常检查,这是最优先的步骤:
      • 无法整除:如果 ,则无法平均分组,输出 error
      • 选组数超限:如果 ,要选择的组数比总组数还多,不合逻辑,输出 error
      • 可选专家不足:每组专家数 个组总共有 个专家。如果这个数量小于 ,即 ,那么不可能选出 个专家,输出 error
    • 如果任何一个条件满足,程序直接结束。
  2. 分组并找出代表

    • 计算每组的专家数量
    • 创建一个数据结构(例如,一个列表或数组)来存储每个组的代表信息。这个结构需要包含:组的原始编号、代表的概率值、代表的专家编号。
    • 遍历 个组(从组 到组 ):
      • 对于每个组 ,其专家编号范围是 [i*g, (i+1)*g - 1]
      • 在这个范围内寻找概率最高的专家,记录下其概率和编号。
      • 将(组号 ,最大概率,对应专家编号)存入代表列表中。
  3. 排序并选择Top-p组

    • 对上一步生成的代表列表进行排序。排序规则为:
      • 主要按概率降序排列。
      • 如果概率相同,则按组的原始编号升序排列。
    • 排序后,选取列表中的前 项。这 项对应的组就是我们接下来要处理的组。
  4. 汇集、排序并选择Top-k专家

    • 创建一个新的列表,用于存放候选专家。
    • 遍历刚才选出的 个组。
    • 对于每个被选中的组,将其包含的所有专家(及其概率和原始编号)都添加到候选专家列表中。
    • 对这个候选专家列表进行排序。排序规则为:
      • 主要按概率降序排列。
      • 如果概率相同,则按专家的原始编号升序排列。
    • 排序后,选取列表中的前 项。
  5. 输出结果

    • 从最终选出的 个专家中,提取出他们的原始编号。
    • 对这 个编号进行升序排序。
    • 按格式要求,用空格分隔输出。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

struct Expert {
    int id;
    double prob;
};

struct GroupRep {
    int id;
    double max_prob;
    int expert_id;
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m, p, k;
    cin >> n >> m >> p >> k;

    if (n % m != 0 || p > m) {
        cout << "error\n";
        return 0;
    }

    int g = n / m;
    if (1LL * p * g < k) {
        cout << "error\n";
        return 0;
    }

    vector<Expert> experts(n);
    for (int i = 0; i < n; ++i) {
        experts[i].id = i;
        cin >> experts[i].prob;
    }

    vector<GroupRep> reps;
    for (int i = 0; i < m; ++i) {
        int start_idx = i * g;
        int end_idx = start_idx + g;
        
        int max_expert_id = start_idx;
        double max_prob = experts[start_idx].prob;

        for (int j = start_idx + 1; j < end_idx; ++j) {
            if (experts[j].prob > max_prob) {
                max_prob = experts[j].prob;
                max_expert_id = j;
            }
        }
        reps.push_back({i, max_prob, max_expert_id});
    }

    sort(reps.begin(), reps.end(), [](const GroupRep& a, const GroupRep& b) {
        if (a.max_prob != b.max_prob) {
            return a.max_prob > b.max_prob;
        }
        return a.id < b.id;
    });

    vector<Expert> candidate_experts;
    for (int i = 0; i < p; ++i) {
        int group_id = reps[i].id;
        int start_idx = group_id * g;
        int end_idx = start_idx + g;
        for (int j = start_idx; j < end_idx; ++j) {
            candidate_experts.push_back(experts[j]);
        }
    }

    sort(candidate_experts.begin(), candidate_experts.end(), [](const Expert& a, const Expert& b) {
        if (a.prob != b.prob) {
            return a.prob > b.prob;
        }
        return a.id < b.id;
    });

    vector<int> result_ids;
    for (int i = 0; i < k; ++i) {
        result_ids.push_back(candidate_experts[i].id);
    }

    sort(result_ids.begin(), result_ids.end());

    for (int i = 0; i < k; ++i) {
        cout << result_ids[i] << (i == k - 1 ? "" : " ");
    }
    cout << "\n";

    return 0;
}
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Scanner;

public class Main {
    static class Expert {
        int id;
        double prob;
        Expert(int id, double prob) {
            this.id = id;
            this.prob = prob;
        }
    }

    static class GroupRep {
        int id;
        double maxProb;
        int expertId;
        GroupRep(int id, double maxProb, int expertId) {
            this.id = id;
            this.maxProb = maxProb;
            this.expertId = expertId;
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        int p = sc.nextInt();
        int k = sc.nextInt();

        if (n % m != 0 || p > m) {
            System.out.println("error");
            return;
        }

        int g = n / m;
        if ((long) p * g < k) {
            System.out.println("error");
            return;
        }

        List<Expert> experts = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            experts.add(new Expert(i, sc.nextDouble()));
        }

        List<GroupRep> reps = new ArrayList<>();
        for (int i = 0; i < m; i++) {
            int startIdx = i * g;
            
            int maxExpertId = startIdx;
            double maxProb = experts.get(startIdx).prob;

            for (int j = startIdx + 1; j < startIdx + g; j++) {
                if (experts.get(j).prob > maxProb) {
                    maxProb = experts.get(j).prob;
                    maxExpertId = j;
                }
            }
            reps.add(new GroupRep(i, maxProb, maxExpertId));
        }

        reps.sort((a, b) -> {
            if (a.maxProb != b.maxProb) {
                return Double.compare(b.maxProb, a.maxProb);
            }
            return Integer.compare(a.id, b.id);
        });

        List<Expert> candidateExperts = new ArrayList<>();
        for (int i = 0; i < p; i++) {
            int groupId = reps.get(i).id;
            int startIdx = groupId * g;
            for (int j = 0; j < g; j++) {
                candidateExperts.add(experts.get(startIdx + j));
            }
        }

        candidateExperts.sort((a, b) -> {
            if (a.prob != b.prob) {
                return Double.compare(b.prob, a.prob);
            }
            return Integer.compare(a.id, b.id);
        });

        List<Integer> resultIds = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            resultIds.add(candidateExperts.get(i).id);
        }

        Collections.sort(resultIds);

        for (int i = 0; i < k; i++) {
            System.out.print(resultIds.get(i) + (i == k - 1 ? "" : " "));
        }
        System.out.println();
    }
}
def solve():
    try:
        n, m, p, k = map(int, input().split())
        probs = list(map(float, input().split()))
    except (IOError, ValueError):
        # 处理可能的空输入行
        return

    if n % m != 0 or p > m:
        print("error")
        return

    g = n // m
    if p * g < k:
        print("error")
        return

    experts = [{'id': i, 'prob': prob} for i, prob in enumerate(probs)]

    reps = []
    for i in range(m):
        start_idx = i * g
        group_experts = experts[start_idx : start_idx + g]
        
        # 找到组内概率最大的专家
        # python的max函数在处理(value, key)元组时,如果value相同,会比较key,
        # 但我们需要的是编号最小的,所以用-id来反转编号的比较
        # 更好的方式是写一个稳定的查找
        max_prob_expert = group_experts[0]
        for expert in group_experts[1:]:
            if expert['prob'] > max_prob_expert['prob']:
                max_prob_expert = expert
        
        reps.append({'id': i, 'max_prob': max_prob_expert['prob'], 'expert_id': max_prob_expert['id']})

    reps.sort(key=lambda x: (-x['max_prob'], x['id']))

    candidate_experts = []
    for i in range(p):
        group_id = reps[i]['id']
        start_idx = group_id * g
        candidate_experts.extend(experts[start_idx : start_idx + g])

    candidate_experts.sort(key=lambda x: (-x['prob'], x['id']))

    result_ids = [candidate_experts[i]['id'] for i in range(k)]
    result_ids.sort()

    print(' '.join(map(str, result_ids)))

solve()

算法及复杂度

  • 算法:模拟、排序
  • 时间复杂度,其中
    • :读取输入和寻找各组代表。
    • :对 个组代表进行排序。
    • :对 个候选专家进行排序,这是最主要的时间开销。
    • :对最终结果进行排序。
  • 空间复杂度,主要用于存储所有专家的信息和候选专家的列表。