结构化剪枝后的分类预测

题目分析

给定样本矩阵 )、线性分类器权重矩阵 )和剪枝比例 ,模拟"行剪枝"过程:按 L1 范数从小到大移除 中最不重要的 行及 中对应的列,然后用剪枝后的矩阵做线性变换 + Softmax 预测每个样本的类别。

思路

矩阵剪枝 + 线性分类模拟

按题意逐步实现即可,关键是把每一步的细节处理正确:

  1. 计算剪枝行数 。特判:当 时,强制 (至少剪一行)。
  1. 计算每行 L1 范数:对 的第 行,。按 L1 从小到大排序,移除最小的 行。
  1. 特征对齐 移除的行索引对应 要移除的列索引。保留剩余特征,得到 )和 )。
  1. 线性变换,得到 的得分矩阵。
  1. Stable Softmax 预测:对每个样本的得分向量,先减去最大值再求 ,取 argmax 作为预测类别(相同取最左)。注意:由于 是单调递增函数,减去最大值后 argmax 不变,所以实际上直接对原始得分取 argmax 即可,Softmax 不影响结果。

以样例验证: ,但 ,所以 。三行 L1 范数为 ,移除第 1 行(索引从 0 开始)。剪枝后 保留第 0、2 列, 保留第 0、2 行。计算 后取 argmax,得到 ,与预期一致。

复杂度

  • 时间复杂度:,主要瓶颈是矩阵乘法
  • 空间复杂度:,存储输入矩阵

代码

import sys
import math

def main():
    data = sys.stdin.read().split()
    idx = 0
    n = int(data[idx]); idx += 1
    d = int(data[idx]); idx += 1
    c = int(data[idx]); idx += 1

    X = []
    for i in range(n):
        row = []
        for j in range(d):
            row.append(float(data[idx])); idx += 1
        X.append(row)

    W = []
    for i in range(d):
        row = []
        for j in range(c):
            row.append(float(data[idx])); idx += 1
        W.append(row)

    ratio = float(data[idx])

    # 计算剪枝行数
    k = int(math.floor(ratio * d))
    if ratio > 0 and k == 0:
        k = 1

    # 按 L1 范数排序,找出要移除的行
    l1 = [(sum(abs(W[i][j]) for j in range(c)), i) for i in range(d)]
    l1.sort()
    removed = set(l1[i][1] for i in range(k))
    kept = [i for i in range(d) if i not in removed]

    # 剪枝:保留对应行/列
    W_p = [W[i] for i in kept]
    X_p = [[X[i][j] for j in kept] for i in range(n)]

    # 矩阵乘法 + argmax
    d_new = len(kept)
    res = []
    for i in range(n):
        scores = []
        for j in range(c):
            s = 0.0
            for t in range(d_new):
                s += X_p[i][t] * W_p[t][j]
            scores.append(s)
        # argmax(取最左)
        best = 0
        for j in range(1, c):
            if scores[j] > scores[best]:
                best = j
        res.append(str(best))

    print(' '.join(res))

if __name__ == '__main__':
    main()