题目链接

结构化剪枝后的分类预测

题目描述

在终端设备上部署深度学习模型时,为了减少计算量和内存占用,常常需要对网络进行剪枝。本题模拟了一个对线性分类器进行“结构化剪枝”的过程。

给定一批样本矩阵 列)、一层线性分类器的权重矩阵 列),以及一个剪枝比例 。你需要按照以下规则对权重矩阵 进行剪枝,并用剪枝后的模型对每个样本进行预测,最终输出每个样本的预测类别索引。

剪枝和预测流程:

  1. 计算剪枝指标: 对 的每一行计算其 范数(该行所有元素绝对值之和)。 范数越小的行被认为越不重要。

  2. 确定剪枝行数: 要剪掉的行数 按如下规则确定:

    • (向下取整)。
    • 特殊规则: 如果 且计算出的 ,则强制令 (即只要比例大于0,至少剪掉一行)。如果 ,则不剪枝 ()。
  3. 执行剪枝: 移除 范数最小的 行,得到新的权重矩阵 ,其形状为

  4. 特征对齐: 对应地,从样本矩阵 中删除与 中被移除的行具有相同索引的,得到新的样本矩阵 ,其形状为

  5. 计算线性输出: 计算剪枝后模型的输出分数矩阵 ,其大小为

  6. 预测结果: 对分数矩阵 的每一行,找出最大值所在的列索引(即 argmax)。如果存在多个最大值,取索引最小的那个。这一行的 n 个索引就是最终的预测结果。

输入描述:

  • 第一行:三个整数
  • 接下来 行:矩阵
  • 接下来 行:矩阵
  • 最后一行:一个浮点数

输出描述:

  • 一行 个整数,以空格分隔,表示每个样本的预测类别索引。

解题思路

本题的核心是严格遵循题目描述的步骤,对矩阵进行筛选和运算。

  1. 读取输入:读入 以及矩阵 ,最后读取剪枝比例

  2. 计算剪枝行数 :根据 计算出 ,并处理 时需强制 的特殊情况。

  3. 计算 L1 范数并排序

    • 创建一个数据结构(如 pair 或自定义对象数组)来存储每一行权重的信息,包含 (L1范数, 原始行索引)
    • 遍历 行,计算每一行的 范数,并与行索引一起存入上述结构中。
    • 对这个结构按 范数从小到大进行排序。
  4. 确定要移除的索引

    • 排序后,选取前 个元素。它们对应的原始行索引就是要从 中移除的行,也是要从 中移除的列。
    • 将这 个索引存入一个哈希集合(HashSetunordered_set)中,以便后续快速查找。
  5. 构建剪枝后的矩阵

    • 构建 :遍历 行(索引从 )。如果当前行索引不在待移除索引的集合中,则将该行加入到新的矩阵 中。
    • 构建 :这一步比较技巧。可以遍历 的所有列(索引从 )。如果当前列索引不在待移除索引的集合中,则将该列的所有 个元素加入到新的矩阵 中。实现时,可以先构建 的转置,最后再转置回来,或者直接按列构建。
  6. 矩阵乘法:实现标准的矩阵乘法

    • 的第 行第 列的元素 的第 行和 的第 列的点积计算得出:
  7. Argmax 预测

    • 遍历 行。
    • 对于每一行,找到最大值及其首次出现的索引。
    • 将找到的索引存入结果数组。
  8. 输出结果:打印结果数组。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <algorithm>
#include <iomanip>
#include <unordered_set>

using namespace std;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n, d, c;
    cin >> n >> d >> c;

    vector<vector<double>> X(n, vector<double>(d));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            cin >> X[i][j];
        }
    }

    vector<vector<double>> W(d, vector<double>(c));
    for (int i = 0; i < d; ++i) {
        for (int j = 0; j < c; ++j) {
            cin >> W[i][j];
        }
    }

    double ratio;
    cin >> ratio;

    int k = floor(ratio * d);
    if (ratio > 0 && k == 0) {
        k = 1;
    }

    if (k > 0) {
        vector<pair<double, int>> l1_norms;
        for (int i = 0; i < d; ++i) {
            double norm = 0;
            for (int j = 0; j < c; ++j) {
                norm += abs(W[i][j]);
            }
            l1_norms.push_back({norm, i});
        }

        sort(l1_norms.begin(), l1_norms.end());

        unordered_set<int> removed_indices;
        for (int i = 0; i < k; ++i) {
            removed_indices.insert(l1_norms[i].second);
        }

        vector<vector<double>> W_prime;
        for (int i = 0; i < d; ++i) {
            if (removed_indices.find(i) == removed_indices.end()) {
                W_prime.push_back(W[i]);
            }
        }
        W = W_prime;

        vector<vector<double>> X_prime(n);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < d; ++j) {
                if (removed_indices.find(j) == removed_indices.end()) {
                    X_prime[i].push_back(X[i][j]);
                }
            }
        }
        X = X_prime;
    }
    
    int d_prime = W.size();
    if (d_prime == 0) { // All features pruned
        for (int i = 0; i < n; ++i) {
            cout << 0 << (i == n - 1 ? "" : " ");
        }
        cout << endl;
        return 0;
    }

    vector<vector<double>> h(n, vector<double>(c, 0.0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < c; ++j) {
            for (int p = 0; p < d_prime; ++p) {
                h[i][j] += X[i][p] * W[p][j];
            }
        }
    }

    for (int i = 0; i < n; ++i) {
        int max_idx = 0;
        for (int j = 1; j < c; ++j) {
            if (h[i][j] > h[i][max_idx]) {
                max_idx = j;
            }
        }
        cout << max_idx << (i == n - 1 ? "" : " ");
    }
    cout << endl;

    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int n = sc.nextInt();
        int d = sc.nextInt();
        int c = sc.nextInt();

        double[][] X = new double[n][d];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < d; j++) {
                X[i][j] = sc.nextDouble();
            }
        }

        double[][] W = new double[d][c];
        for (int i = 0; i < d; i++) {
            for (int j = 0; j < c; j++) {
                W[i][j] = sc.nextDouble();
            }
        }

        double ratio = sc.nextDouble();

        int k = (int) Math.floor(ratio * d);
        if (ratio > 0 && k == 0) {
            k = 1;
        }

        if (k > 0) {
            // Pair: {L1 norm, original index}
            double[][] l1Norms = new double[d][2];
            for (int i = 0; i < d; i++) {
                double norm = 0;
                for (int j = 0; j < c; j++) {
                    norm += Math.abs(W[i][j]);
                }
                l1Norms[i][0] = norm;
                l1Norms[i][1] = i;
            }

            Arrays.sort(l1Norms, Comparator.comparingDouble(a -> a[0]));

            Set<Integer> removedIndices = new HashSet<>();
            for (int i = 0; i < k; i++) {
                removedIndices.add((int) l1Norms[i][1]);
            }

            int d_prime = d - k;
            double[][] W_prime = new double[d_prime][c];
            double[][] X_prime = new double[n][d_prime];
            
            int current_row = 0;
            for (int i = 0; i < d; i++) {
                if (!removedIndices.contains(i)) {
                    W_prime[current_row++] = W[i];
                }
            }

            for(int i = 0; i < n; i++){
                int current_col = 0;
                for(int j = 0; j < d; j++){
                    if(!removedIndices.contains(j)){
                        X_prime[i][current_col++] = X[i][j];
                    }
                }
            }
            W = W_prime;
            X = X_prime;
        }

        int d_prime = X[0].length;
        if (d_prime == 0) {
            for (int i = 0; i < n; i++) {
                System.out.print(0 + (i == n - 1 ? "" : " "));
            }
            System.out.println();
            return;
        }


        double[][] h = new double[n][c];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < c; j++) {
                for (int p = 0; p < d_prime; p++) {
                    h[i][j] += X[i][p] * W[p][j];
                }
            }
        }

        for (int i = 0; i < n; i++) {
            int maxIdx = 0;
            for (int j = 1; j < c; j++) {
                if (h[i][j] > h[i][maxIdx]) {
                    maxIdx = j;
                }
            }
            System.out.print(maxIdx + (i == n - 1 ? "" : " "));
        }
        System.out.println();
    }
}
import math

# 读取维度
n, d, c = map(int, input().split())

# 读取矩阵X
X = []
for _ in range(n):
    X.append(list(map(float, input().split())))

# 读取矩阵W
W = []
for _ in range(d):
    W.append(list(map(float, input().split())))

# 读取剪枝比例
ratio = float(input())

# 计算剪枝行数 k
k = math.floor(ratio * d)
if ratio > 0 and k == 0:
    k = 1

if k > 0:
    # 计算每行的 L1 范数及其原始索引
    l1_norms = []
    for i in range(d):
        norm = sum(abs(val) for val in W[i])
        l1_norms.append((norm, i))

    # 按 L1 范数排序
    l1_norms.sort()

    # 获取要移除的 k 个索引
    removed_indices = {l1_norms[i][1] for i in range(k)}

    # 构建剪枝后的 W_prime
    W_prime = [W[i] for i in range(d) if i not in removed_indices]
    W = W_prime

    # 构建剪枝后的 X_prime
    X_prime = [[X[i][j] for j in range(d) if j not in removed_indices] for i in range(n)]
    X = X_prime


d_prime = len(W)
# 如果所有特征都被剪掉
if d_prime == 0:
    # 默认预测为类别0
    print(*([0] * n))
else:
    # 矩阵乘法 h = X * W
    h = [[0.0] * c for _ in range(n)]
    for i in range(n):
        for j in range(c):
            for p in range(d_prime):
                h[i][j] += X[i][p] * W[p][j]

    # Argmax 预测
    predictions = []
    for i in range(n):
        max_val = h[i][0]
        max_idx = 0
        for j in range(1, c):
            if h[i][j] > max_val:
                max_val = h[i][j]
                max_idx = j
        predictions.append(max_idx)
    
    print(*predictions)

算法及复杂度

  • 算法:本题是一个模拟题,严格按照题目描述的步骤进行:计算L1范数、排序、矩阵筛选、矩阵乘法、argmax
  • 时间复杂度
    • 读取数据:
    • 计算所有行的L1范数:
    • 对L1范数进行排序:
    • 构建剪枝后的矩阵
    • 矩阵乘法 :设剪枝后维度为 ,则复杂度为 ,最坏情况下是
    • Argmax预测:
    • 其中,矩阵乘法是整个算法中复杂度最高的部分,因此总体时间复杂度由它决定。
  • 空间复杂度。主要开销在于存储输入的矩阵 以及剪枝后和计算中的中间矩阵。