题目链接
题目描述
在终端设备上部署深度学习模型时,为了减少计算量和内存占用,常常需要对网络进行剪枝。本题模拟了一个对线性分类器进行“结构化剪枝”的过程。
给定一批样本矩阵 (
行
列)、一层线性分类器的权重矩阵
(
行
列),以及一个剪枝比例
。你需要按照以下规则对权重矩阵
进行剪枝,并用剪枝后的模型对每个样本进行预测,最终输出每个样本的预测类别索引。
剪枝和预测流程:
-
计算剪枝指标: 对
的每一行计算其
范数(该行所有元素绝对值之和)。
范数越小的行被认为越不重要。
-
确定剪枝行数: 要剪掉的行数
按如下规则确定:
(向下取整)。
- 特殊规则: 如果
且计算出的
,则强制令
(即只要比例大于0,至少剪掉一行)。如果
,则不剪枝 (
)。
-
执行剪枝: 移除
中
范数最小的
行,得到新的权重矩阵
,其形状为
。
-
特征对齐: 对应地,从样本矩阵
中删除与
中被移除的行具有相同索引的列,得到新的样本矩阵
,其形状为
。
-
计算线性输出: 计算剪枝后模型的输出分数矩阵
,其大小为
。
-
预测结果: 对分数矩阵
的每一行,找出最大值所在的列索引(即
argmax)。如果存在多个最大值,取索引最小的那个。这一行的n个索引就是最终的预测结果。
输入描述:
- 第一行:三个整数
。
- 接下来
行:矩阵
。
- 接下来
行:矩阵
。
- 最后一行:一个浮点数
。
输出描述:
- 一行
个整数,以空格分隔,表示每个样本的预测类别索引。
解题思路
本题的核心是严格遵循题目描述的步骤,对矩阵进行筛选和运算。
-
读取输入:读入
以及矩阵
和
,最后读取剪枝比例
。
-
计算剪枝行数
:根据
和
计算出
,并处理
且
时需强制
的特殊情况。
-
计算 L1 范数并排序:
- 创建一个数据结构(如 pair 或自定义对象数组)来存储每一行权重的信息,包含
(L1范数, 原始行索引)。 - 遍历
的
行,计算每一行的
范数,并与行索引一起存入上述结构中。
- 对这个结构按
范数从小到大进行排序。
- 创建一个数据结构(如 pair 或自定义对象数组)来存储每一行权重的信息,包含
-
确定要移除的索引:
- 排序后,选取前
个元素。它们对应的原始行索引就是要从
中移除的行,也是要从
中移除的列。
- 将这
个索引存入一个哈希集合(
HashSet或unordered_set)中,以便后续快速查找。
- 排序后,选取前
-
构建剪枝后的矩阵
和
:
- 构建
:遍历
的
行(索引从
到
)。如果当前行索引不在待移除索引的集合中,则将该行加入到新的矩阵
中。
- 构建
:这一步比较技巧。可以遍历
的所有列(索引从
到
)。如果当前列索引不在待移除索引的集合中,则将该列的所有
个元素加入到新的矩阵
中。实现时,可以先构建
的转置,最后再转置回来,或者直接按列构建。
- 构建
-
矩阵乘法:实现标准的矩阵乘法
。
的第
行第
列的元素
由
的第
行和
的第
列的点积计算得出:
。
-
Argmax 预测:
- 遍历
的
行。
- 对于每一行,找到最大值及其首次出现的索引。
- 将找到的索引存入结果数组。
- 遍历
-
输出结果:打印结果数组。
代码
#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <algorithm>
#include <iomanip>
#include <unordered_set>
using namespace std;
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int n, d, c;
cin >> n >> d >> c;
vector<vector<double>> X(n, vector<double>(d));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < d; ++j) {
cin >> X[i][j];
}
}
vector<vector<double>> W(d, vector<double>(c));
for (int i = 0; i < d; ++i) {
for (int j = 0; j < c; ++j) {
cin >> W[i][j];
}
}
double ratio;
cin >> ratio;
int k = floor(ratio * d);
if (ratio > 0 && k == 0) {
k = 1;
}
if (k > 0) {
vector<pair<double, int>> l1_norms;
for (int i = 0; i < d; ++i) {
double norm = 0;
for (int j = 0; j < c; ++j) {
norm += abs(W[i][j]);
}
l1_norms.push_back({norm, i});
}
sort(l1_norms.begin(), l1_norms.end());
unordered_set<int> removed_indices;
for (int i = 0; i < k; ++i) {
removed_indices.insert(l1_norms[i].second);
}
vector<vector<double>> W_prime;
for (int i = 0; i < d; ++i) {
if (removed_indices.find(i) == removed_indices.end()) {
W_prime.push_back(W[i]);
}
}
W = W_prime;
vector<vector<double>> X_prime(n);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < d; ++j) {
if (removed_indices.find(j) == removed_indices.end()) {
X_prime[i].push_back(X[i][j]);
}
}
}
X = X_prime;
}
int d_prime = W.size();
if (d_prime == 0) { // All features pruned
for (int i = 0; i < n; ++i) {
cout << 0 << (i == n - 1 ? "" : " ");
}
cout << endl;
return 0;
}
vector<vector<double>> h(n, vector<double>(c, 0.0));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < c; ++j) {
for (int p = 0; p < d_prime; ++p) {
h[i][j] += X[i][p] * W[p][j];
}
}
}
for (int i = 0; i < n; ++i) {
int max_idx = 0;
for (int j = 1; j < c; ++j) {
if (h[i][j] > h[i][max_idx]) {
max_idx = j;
}
}
cout << max_idx << (i == n - 1 ? "" : " ");
}
cout << endl;
return 0;
}
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int d = sc.nextInt();
int c = sc.nextInt();
double[][] X = new double[n][d];
for (int i = 0; i < n; i++) {
for (int j = 0; j < d; j++) {
X[i][j] = sc.nextDouble();
}
}
double[][] W = new double[d][c];
for (int i = 0; i < d; i++) {
for (int j = 0; j < c; j++) {
W[i][j] = sc.nextDouble();
}
}
double ratio = sc.nextDouble();
int k = (int) Math.floor(ratio * d);
if (ratio > 0 && k == 0) {
k = 1;
}
if (k > 0) {
// Pair: {L1 norm, original index}
double[][] l1Norms = new double[d][2];
for (int i = 0; i < d; i++) {
double norm = 0;
for (int j = 0; j < c; j++) {
norm += Math.abs(W[i][j]);
}
l1Norms[i][0] = norm;
l1Norms[i][1] = i;
}
Arrays.sort(l1Norms, Comparator.comparingDouble(a -> a[0]));
Set<Integer> removedIndices = new HashSet<>();
for (int i = 0; i < k; i++) {
removedIndices.add((int) l1Norms[i][1]);
}
int d_prime = d - k;
double[][] W_prime = new double[d_prime][c];
double[][] X_prime = new double[n][d_prime];
int current_row = 0;
for (int i = 0; i < d; i++) {
if (!removedIndices.contains(i)) {
W_prime[current_row++] = W[i];
}
}
for(int i = 0; i < n; i++){
int current_col = 0;
for(int j = 0; j < d; j++){
if(!removedIndices.contains(j)){
X_prime[i][current_col++] = X[i][j];
}
}
}
W = W_prime;
X = X_prime;
}
int d_prime = X[0].length;
if (d_prime == 0) {
for (int i = 0; i < n; i++) {
System.out.print(0 + (i == n - 1 ? "" : " "));
}
System.out.println();
return;
}
double[][] h = new double[n][c];
for (int i = 0; i < n; i++) {
for (int j = 0; j < c; j++) {
for (int p = 0; p < d_prime; p++) {
h[i][j] += X[i][p] * W[p][j];
}
}
}
for (int i = 0; i < n; i++) {
int maxIdx = 0;
for (int j = 1; j < c; j++) {
if (h[i][j] > h[i][maxIdx]) {
maxIdx = j;
}
}
System.out.print(maxIdx + (i == n - 1 ? "" : " "));
}
System.out.println();
}
}
import math
# 读取维度
n, d, c = map(int, input().split())
# 读取矩阵X
X = []
for _ in range(n):
X.append(list(map(float, input().split())))
# 读取矩阵W
W = []
for _ in range(d):
W.append(list(map(float, input().split())))
# 读取剪枝比例
ratio = float(input())
# 计算剪枝行数 k
k = math.floor(ratio * d)
if ratio > 0 and k == 0:
k = 1
if k > 0:
# 计算每行的 L1 范数及其原始索引
l1_norms = []
for i in range(d):
norm = sum(abs(val) for val in W[i])
l1_norms.append((norm, i))
# 按 L1 范数排序
l1_norms.sort()
# 获取要移除的 k 个索引
removed_indices = {l1_norms[i][1] for i in range(k)}
# 构建剪枝后的 W_prime
W_prime = [W[i] for i in range(d) if i not in removed_indices]
W = W_prime
# 构建剪枝后的 X_prime
X_prime = [[X[i][j] for j in range(d) if j not in removed_indices] for i in range(n)]
X = X_prime
d_prime = len(W)
# 如果所有特征都被剪掉
if d_prime == 0:
# 默认预测为类别0
print(*([0] * n))
else:
# 矩阵乘法 h = X * W
h = [[0.0] * c for _ in range(n)]
for i in range(n):
for j in range(c):
for p in range(d_prime):
h[i][j] += X[i][p] * W[p][j]
# Argmax 预测
predictions = []
for i in range(n):
max_val = h[i][0]
max_idx = 0
for j in range(1, c):
if h[i][j] > max_val:
max_val = h[i][j]
max_idx = j
predictions.append(max_idx)
print(*predictions)
算法及复杂度
- 算法:本题是一个模拟题,严格按照题目描述的步骤进行:计算L1范数、排序、矩阵筛选、矩阵乘法、
argmax。 - 时间复杂度:
。
- 读取数据:
。
- 计算所有行的L1范数:
。
- 对L1范数进行排序:
。
- 构建剪枝后的矩阵
和
:
。
- 矩阵乘法
:设剪枝后维度为
,则复杂度为
,最坏情况下是
。
- Argmax预测:
。
- 其中,矩阵乘法是整个算法中复杂度最高的部分,因此总体时间复杂度由它决定。
- 读取数据:
- 空间复杂度:
。主要开销在于存储输入的矩阵
和
以及剪枝后和计算中的中间矩阵。

京公网安备 11010502036488号