MOE Top-k 路由

题意

稀疏 MOE 模型有 个专家(编号 ),均匀分布在 张 NPU 卡上,每张卡上的专家编号连续。为了减少跨卡通信,路由时最多只选 张卡。

算法分三步:

  1. 分组代表:每组取概率最大的专家作为代表(概率相同取编号小的)
  2. 选组:所有组按代表概率降序排序(相同概率取组号小的),选前
  3. 选专家:从这 组的全部专家中,按概率降序(相同概率取编号小的)选前

输出这 个专家的编号,升序输出。

三种情况输出 error 不能被 整除、、可选专家数

思路

这题没有算法上的难度,就是个模拟题,关键是把三步流程理清楚,排序的比较规则写对。

先处理边界情况。 说明没法均分, 说明选的卡比总卡数还多, 是每组大小)说明选出来的专家总数不够 个——这三种都直接输出 error

然后模拟三步就行了。

第一步,把 个专家按每组 个切分,每组里找概率最大的。这一步的结果是每组一个代表概率。

第二步,对 个组按 排序,取前 组。

第三步,把这 组里的所有专家拿出来,按 排序,取前 个。最后把选出来的 个编号升序输出。

注意排序的 tie-breaking 规则:概率相同时,不管是选组还是选专家,都是编号/组号小的优先。

时间复杂度 ,瓶颈在排序。空间

代码

import sys

def solve():
    data = sys.stdin.read().split()
    idx = 0
    n = int(data[idx]); idx += 1
    m = int(data[idx]); idx += 1
    p = int(data[idx]); idx += 1
    k = int(data[idx]); idx += 1

    probs = []
    for i in range(n):
        probs.append(float(data[idx])); idx += 1

    if n % m != 0 or p > m:
        print("error")
        return

    g = n // m
    if p * g < k:
        print("error")
        return

    # 第一步:每组选代表
    groups = []
    for i in range(m):
        start = i * g
        best_prob = -1
        for j in range(start, start + g):
            if probs[j] > best_prob:
                best_prob = probs[j]
        groups.append((best_prob, i))

    # 第二步:选前 p 组
    groups.sort(key=lambda x: (-x[0], x[1]))
    selected = [groups[i][1] for i in range(p)]

    # 第三步:从选中组里选前 k 个专家
    candidates = []
    for gi in selected:
        start = gi * g
        for j in range(start, start + g):
            candidates.append((probs[j], j))

    candidates.sort(key=lambda x: (-x[0], x[1]))
    result = sorted(candidates[i][1] for i in range(k))
    print(' '.join(map(str, result)))

solve()
#include <bits/stdc++.h>
using namespace std;

int main(){
    int n, m, p, k;
    scanf("%d%d%d%d", &n, &m, &p, &k);
    vector<double> prob(n);
    for(int i = 0; i < n; i++) scanf("%lf", &prob[i]);

    if(n % m != 0 || p > m){
        puts("error");
        return 0;
    }
    int g = n / m;
    if((long long)p * g < k){
        puts("error");
        return 0;
    }

    // 第一步:每组选代表
    vector<pair<double,int>> groups(m);
    for(int i = 0; i < m; i++){
        int s = i * g;
        double best = -1;
        for(int j = s; j < s + g; j++)
            if(prob[j] > best) best = prob[j];
        groups[i] = {best, i};
    }

    // 第二步:选前 p 组
    sort(groups.begin(), groups.end(), [](auto& a, auto& b){
        if(a.first != b.first) return a.first > b.first;
        return a.second < b.second;
    });

    // 第三步:从选中组里选前 k 个专家
    vector<pair<double,int>> cands;
    for(int i = 0; i < p; i++){
        int gi = groups[i].second, s = gi * g;
        for(int j = s; j < s + g; j++)
            cands.push_back({prob[j], j});
    }

    sort(cands.begin(), cands.end(), [](auto& a, auto& b){
        if(a.first != b.first) return a.first > b.first;
        return a.second < b.second;
    });

    vector<int> res;
    for(int i = 0; i < k; i++) res.push_back(cands[i].second);
    sort(res.begin(), res.end());
    for(int i = 0; i < k; i++){
        if(i) printf(" ");
        printf("%d", res[i]);
    }
    puts("");
    return 0;
}