结构化剪枝后的分类预测
题目分析
给定样本矩阵 (
)、线性分类器权重矩阵
(
)和剪枝比例
,模拟"行剪枝"过程:按 L1 范数从小到大移除
中最不重要的
行及
中对应的列,然后用剪枝后的矩阵做线性变换 + Softmax 预测每个样本的类别。
思路
矩阵剪枝 + 线性分类模拟
按题意逐步实现即可,关键是把每一步的细节处理正确:
- 计算剪枝行数
:
。特判:当
且
时,强制
(至少剪一行)。
- 计算每行 L1 范数:对
的第
行,
。按 L1 从小到大排序,移除最小的
行。
- 特征对齐:
移除的行索引对应
要移除的列索引。保留剩余特征,得到
(
)和
(
)。
- 线性变换:
,得到
的得分矩阵。
- Stable Softmax 预测:对每个样本的得分向量,先减去最大值再求
,取 argmax 作为预测类别(相同取最左)。注意:由于
是单调递增函数,减去最大值后 argmax 不变,所以实际上直接对原始得分取 argmax 即可,Softmax 不影响结果。
以样例验证: ,
。
,但
,所以
。三行 L1 范数为
,移除第 1 行(索引从 0 开始)。剪枝后
保留第 0、2 列,
保留第 0、2 行。计算
后取 argmax,得到
,与预期一致。
复杂度
- 时间复杂度:
,主要瓶颈是矩阵乘法
- 空间复杂度:
,存储输入矩阵
代码
import sys
import math
def main():
data = sys.stdin.read().split()
idx = 0
n = int(data[idx]); idx += 1
d = int(data[idx]); idx += 1
c = int(data[idx]); idx += 1
X = []
for i in range(n):
row = []
for j in range(d):
row.append(float(data[idx])); idx += 1
X.append(row)
W = []
for i in range(d):
row = []
for j in range(c):
row.append(float(data[idx])); idx += 1
W.append(row)
ratio = float(data[idx])
# 计算剪枝行数
k = int(math.floor(ratio * d))
if ratio > 0 and k == 0:
k = 1
# 按 L1 范数排序,找出要移除的行
l1 = [(sum(abs(W[i][j]) for j in range(c)), i) for i in range(d)]
l1.sort()
removed = set(l1[i][1] for i in range(k))
kept = [i for i in range(d) if i not in removed]
# 剪枝:保留对应行/列
W_p = [W[i] for i in kept]
X_p = [[X[i][j] for j in kept] for i in range(n)]
# 矩阵乘法 + argmax
d_new = len(kept)
res = []
for i in range(n):
scores = []
for j in range(c):
s = 0.0
for t in range(d_new):
s += X_p[i][t] * W_p[t][j]
scores.append(s)
# argmax(取最左)
best = 0
for j in range(1, c):
if scores[j] > scores[best]:
best = j
res.append(str(best))
print(' '.join(res))
if __name__ == '__main__':
main()

京公网安备 11010502036488号