题目链接
题目描述
本题模拟一个稀疏 MOE (Mixture of Experts) 模型中的路由选择过程。给定 个专家和
张 NPU 卡,需要按照以下三步规则,从所有专家中选出最终的
个作为路由目标:
-
分组与代表选举:
- 首先,将
个专家平均分配到
张卡上,形成
个组。
- 在每个组内部,选出概率最高的专家作为该组的“代表”。
- 首先,将
-
顶层路由(选组):
- 对所有组的“代表”进行排序(按概率降序,概率相同时组号小的优先)。
- 从排序后的结果中,选取前
个组。
-
二层路由(选专家):
- 将上一步选出的
个组中的所有专家汇集起来。
- 对这些专家进行排序(按概率降序,概率相同时专家编号小的优先)。
- 选取前
位专家作为最终结果。
- 将上一步选出的
同时,需要处理一些约束和异常情况,如无法平均分组、参数不合理等。
解题思路
这是一个多阶段的排序和筛选问题。解题的关键在于清晰地实现每一步的逻辑,并正确处理排序的复合条件。
-
输入与异常检查
- 首先读取四个整数
和
个专家的概率。
- 立即进行异常检查,这是最优先的步骤:
- 无法整除:如果
,则无法平均分组,输出
error
。 - 选组数超限:如果
,要选择的组数比总组数还多,不合逻辑,输出
error
。 - 可选专家不足:每组专家数
。
个组总共有
个专家。如果这个数量小于
,即
,那么不可能选出
个专家,输出
error
。
- 无法整除:如果
- 如果任何一个条件满足,程序直接结束。
- 首先读取四个整数
-
分组并找出代表
- 计算每组的专家数量
。
- 创建一个数据结构(例如,一个列表或数组)来存储每个组的代表信息。这个结构需要包含:组的原始编号、代表的概率值、代表的专家编号。
- 遍历
个组(从组
到组
):
- 对于每个组
,其专家编号范围是
[i*g, (i+1)*g - 1]
。 - 在这个范围内寻找概率最高的专家,记录下其概率和编号。
- 将(组号
,最大概率,对应专家编号)存入代表列表中。
- 对于每个组
- 计算每组的专家数量
-
排序并选择Top-p组
- 对上一步生成的代表列表进行排序。排序规则为:
- 主要按概率降序排列。
- 如果概率相同,则按组的原始编号升序排列。
- 排序后,选取列表中的前
项。这
项对应的组就是我们接下来要处理的组。
- 对上一步生成的代表列表进行排序。排序规则为:
-
汇集、排序并选择Top-k专家
- 创建一个新的列表,用于存放候选专家。
- 遍历刚才选出的
个组。
- 对于每个被选中的组,将其包含的所有专家(及其概率和原始编号)都添加到候选专家列表中。
- 对这个候选专家列表进行排序。排序规则为:
- 主要按概率降序排列。
- 如果概率相同,则按专家的原始编号升序排列。
- 排序后,选取列表中的前
项。
-
输出结果
- 从最终选出的
个专家中,提取出他们的原始编号。
- 对这
个编号进行升序排序。
- 按格式要求,用空格分隔输出。
- 从最终选出的
代码
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
using namespace std;
struct Expert {
int id;
double prob;
};
struct GroupRep {
int id;
double max_prob;
int expert_id;
};
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, m, p, k;
cin >> n >> m >> p >> k;
if (n % m != 0 || p > m) {
cout << "error\n";
return 0;
}
int g = n / m;
if (1LL * p * g < k) {
cout << "error\n";
return 0;
}
vector<Expert> experts(n);
for (int i = 0; i < n; ++i) {
experts[i].id = i;
cin >> experts[i].prob;
}
vector<GroupRep> reps;
for (int i = 0; i < m; ++i) {
int start_idx = i * g;
int end_idx = start_idx + g;
int max_expert_id = start_idx;
double max_prob = experts[start_idx].prob;
for (int j = start_idx + 1; j < end_idx; ++j) {
if (experts[j].prob > max_prob) {
max_prob = experts[j].prob;
max_expert_id = j;
}
}
reps.push_back({i, max_prob, max_expert_id});
}
sort(reps.begin(), reps.end(), [](const GroupRep& a, const GroupRep& b) {
if (a.max_prob != b.max_prob) {
return a.max_prob > b.max_prob;
}
return a.id < b.id;
});
vector<Expert> candidate_experts;
for (int i = 0; i < p; ++i) {
int group_id = reps[i].id;
int start_idx = group_id * g;
int end_idx = start_idx + g;
for (int j = start_idx; j < end_idx; ++j) {
candidate_experts.push_back(experts[j]);
}
}
sort(candidate_experts.begin(), candidate_experts.end(), [](const Expert& a, const Expert& b) {
if (a.prob != b.prob) {
return a.prob > b.prob;
}
return a.id < b.id;
});
vector<int> result_ids;
for (int i = 0; i < k; ++i) {
result_ids.push_back(candidate_experts[i].id);
}
sort(result_ids.begin(), result_ids.end());
for (int i = 0; i < k; ++i) {
cout << result_ids[i] << (i == k - 1 ? "" : " ");
}
cout << "\n";
return 0;
}
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Scanner;
public class Main {
static class Expert {
int id;
double prob;
Expert(int id, double prob) {
this.id = id;
this.prob = prob;
}
}
static class GroupRep {
int id;
double maxProb;
int expertId;
GroupRep(int id, double maxProb, int expertId) {
this.id = id;
this.maxProb = maxProb;
this.expertId = expertId;
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
int p = sc.nextInt();
int k = sc.nextInt();
if (n % m != 0 || p > m) {
System.out.println("error");
return;
}
int g = n / m;
if ((long) p * g < k) {
System.out.println("error");
return;
}
List<Expert> experts = new ArrayList<>();
for (int i = 0; i < n; i++) {
experts.add(new Expert(i, sc.nextDouble()));
}
List<GroupRep> reps = new ArrayList<>();
for (int i = 0; i < m; i++) {
int startIdx = i * g;
int maxExpertId = startIdx;
double maxProb = experts.get(startIdx).prob;
for (int j = startIdx + 1; j < startIdx + g; j++) {
if (experts.get(j).prob > maxProb) {
maxProb = experts.get(j).prob;
maxExpertId = j;
}
}
reps.add(new GroupRep(i, maxProb, maxExpertId));
}
reps.sort((a, b) -> {
if (a.maxProb != b.maxProb) {
return Double.compare(b.maxProb, a.maxProb);
}
return Integer.compare(a.id, b.id);
});
List<Expert> candidateExperts = new ArrayList<>();
for (int i = 0; i < p; i++) {
int groupId = reps.get(i).id;
int startIdx = groupId * g;
for (int j = 0; j < g; j++) {
candidateExperts.add(experts.get(startIdx + j));
}
}
candidateExperts.sort((a, b) -> {
if (a.prob != b.prob) {
return Double.compare(b.prob, a.prob);
}
return Integer.compare(a.id, b.id);
});
List<Integer> resultIds = new ArrayList<>();
for (int i = 0; i < k; i++) {
resultIds.add(candidateExperts.get(i).id);
}
Collections.sort(resultIds);
for (int i = 0; i < k; i++) {
System.out.print(resultIds.get(i) + (i == k - 1 ? "" : " "));
}
System.out.println();
}
}
def solve():
try:
n, m, p, k = map(int, input().split())
probs = list(map(float, input().split()))
except (IOError, ValueError):
# 处理可能的空输入行
return
if n % m != 0 or p > m:
print("error")
return
g = n // m
if p * g < k:
print("error")
return
experts = [{'id': i, 'prob': prob} for i, prob in enumerate(probs)]
reps = []
for i in range(m):
start_idx = i * g
group_experts = experts[start_idx : start_idx + g]
# 找到组内概率最大的专家
# python的max函数在处理(value, key)元组时,如果value相同,会比较key,
# 但我们需要的是编号最小的,所以用-id来反转编号的比较
# 更好的方式是写一个稳定的查找
max_prob_expert = group_experts[0]
for expert in group_experts[1:]:
if expert['prob'] > max_prob_expert['prob']:
max_prob_expert = expert
reps.append({'id': i, 'max_prob': max_prob_expert['prob'], 'expert_id': max_prob_expert['id']})
reps.sort(key=lambda x: (-x['max_prob'], x['id']))
candidate_experts = []
for i in range(p):
group_id = reps[i]['id']
start_idx = group_id * g
candidate_experts.extend(experts[start_idx : start_idx + g])
candidate_experts.sort(key=lambda x: (-x['prob'], x['id']))
result_ids = [candidate_experts[i]['id'] for i in range(k)]
result_ids.sort()
print(' '.join(map(str, result_ids)))
solve()
算法及复杂度
- 算法:模拟、排序
- 时间复杂度:
,其中
。
:读取输入和寻找各组代表。
:对
个组代表进行排序。
:对
个候选专家进行排序,这是最主要的时间开销。
:对最终结果进行排序。
- 空间复杂度:
,主要用于存储所有专家的信息和候选专家的列表。