MOE Top-k 路由
题意
稀疏 MOE 模型有 个专家(编号
到
),均匀分布在
张 NPU 卡上,每张卡上的专家编号连续。为了减少跨卡通信,路由时最多只选
张卡。
算法分三步:
- 分组代表:每组取概率最大的专家作为代表(概率相同取编号小的)
- 选组:所有组按代表概率降序排序(相同概率取组号小的),选前
组
- 选专家:从这
组的全部专家中,按概率降序(相同概率取编号小的)选前
个
输出这 个专家的编号,升序输出。
三种情况输出 error: 不能被
整除、
、可选专家数
。
思路
这题没有算法上的难度,就是个模拟题,关键是把三步流程理清楚,排序的比较规则写对。
先处理边界情况。 说明没法均分,
说明选的卡比总卡数还多,
(
是每组大小)说明选出来的专家总数不够
个——这三种都直接输出
error。
然后模拟三步就行了。
第一步,把 个专家按每组
个切分,每组里找概率最大的。这一步的结果是每组一个代表概率。
第二步,对 个组按
排序,取前
组。
第三步,把这 组里的所有专家拿出来,按
排序,取前
个。最后把选出来的
个编号升序输出。
注意排序的 tie-breaking 规则:概率相同时,不管是选组还是选专家,都是编号/组号小的优先。
时间复杂度 ,瓶颈在排序。空间
。
代码
import sys
def solve():
data = sys.stdin.read().split()
idx = 0
n = int(data[idx]); idx += 1
m = int(data[idx]); idx += 1
p = int(data[idx]); idx += 1
k = int(data[idx]); idx += 1
probs = []
for i in range(n):
probs.append(float(data[idx])); idx += 1
if n % m != 0 or p > m:
print("error")
return
g = n // m
if p * g < k:
print("error")
return
# 第一步:每组选代表
groups = []
for i in range(m):
start = i * g
best_prob = -1
for j in range(start, start + g):
if probs[j] > best_prob:
best_prob = probs[j]
groups.append((best_prob, i))
# 第二步:选前 p 组
groups.sort(key=lambda x: (-x[0], x[1]))
selected = [groups[i][1] for i in range(p)]
# 第三步:从选中组里选前 k 个专家
candidates = []
for gi in selected:
start = gi * g
for j in range(start, start + g):
candidates.append((probs[j], j))
candidates.sort(key=lambda x: (-x[0], x[1]))
result = sorted(candidates[i][1] for i in range(k))
print(' '.join(map(str, result)))
solve()
#include <bits/stdc++.h>
using namespace std;
int main(){
int n, m, p, k;
scanf("%d%d%d%d", &n, &m, &p, &k);
vector<double> prob(n);
for(int i = 0; i < n; i++) scanf("%lf", &prob[i]);
if(n % m != 0 || p > m){
puts("error");
return 0;
}
int g = n / m;
if((long long)p * g < k){
puts("error");
return 0;
}
// 第一步:每组选代表
vector<pair<double,int>> groups(m);
for(int i = 0; i < m; i++){
int s = i * g;
double best = -1;
for(int j = s; j < s + g; j++)
if(prob[j] > best) best = prob[j];
groups[i] = {best, i};
}
// 第二步:选前 p 组
sort(groups.begin(), groups.end(), [](auto& a, auto& b){
if(a.first != b.first) return a.first > b.first;
return a.second < b.second;
});
// 第三步:从选中组里选前 k 个专家
vector<pair<double,int>> cands;
for(int i = 0; i < p; i++){
int gi = groups[i].second, s = gi * g;
for(int j = s; j < s + g; j++)
cands.push_back({prob[j], j});
}
sort(cands.begin(), cands.end(), [](auto& a, auto& b){
if(a.first != b.first) return a.first > b.first;
return a.second < b.second;
});
vector<int> res;
for(int i = 0; i < k; i++) res.push_back(cands[i].second);
sort(res.begin(), res.end());
for(int i = 0; i < k; i++){
if(i) printf(" ");
printf("%d", res[i]);
}
puts("");
return 0;
}

京公网安备 11010502036488号