题目链接

模型量化最小误差

题目描述

在一台边缘设备上部署神经网络,需要对权重矩阵进行量化。

  • 网络共有 层,每层有 个实数权重。
  • 每一层必须选择一个量化位宽 ,且
  • 所有层选择的位宽之和不能超过
  • 量化过程:
    1. 放大并取整:
    2. 还原:
    3. 误差定义:该层所有权重的 之和。
  • 目标:在总位宽预算内,使全网总误差(各层误差之和)最小。
  • 输出:最小总误差 后向下取整的结果。

输入:

  • 第一行:
  • 接下来 行:每行 个实数,表示对应层的权重。

输出:

  • 一个整数,为最小总误差 后向下取整的结果。

解题思路

这是一个典型的动态规划问题,类似于分组背包问题。

  1. 预处理每层的误差: 对于每一层 ),分别计算选择位宽 时产生的总误差 根据题目描述,量化采用的是向零取整或直接截断的 ,还原时除以

  2. 状态定义: 设 表示前 层在总位宽预算为 时所能达到的最小全网误差。

    • 的范围:
    • 的范围:
  3. 状态转移方程,其中 。 初始状态:,其余

  4. 结果处理: 最终答案为 。 输出为 。为了保证精度,建议在处理误差时使用高精度浮点数,并在最后取整时加上微小的

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>

using namespace std;

typedef long double LD;

int main() {
    int n, h, q_max;
    cin >> n >> h >> q_max;

    // e[i][q] 存储第 i 层选择位宽 q 时的总误差
    vector<vector<LD>> e(n, vector<LD>(9, 0.0));
    int bits[] = {2, 4, 8};

    for (int i = 0; i < n; ++i) {
        vector<double> weights(h);
        for (int j = 0; j < h; ++j) cin >> weights[j];

        for (int q : bits) {
            LD layer_err = 0;
            double p = pow(2, q);
            for (int j = 0; j < h; ++j) {
                double wq = floor(weights[j] * p);
                double wr = wq / p;
                layer_err += abs((LD)weights[j] - (LD)wr);
            }
            e[i][q] = layer_err;
        }
    }

    const LD INF = 1e18;
    vector<vector<LD>> dp(n + 1, vector<LD>(q_max + 1, INF));
    dp[0][0] = 0;

    for (int i = 1; i <= n; ++i) {
        for (int j = 0; j <= q_max; ++j) {
            for (int q : bits) {
                if (j >= q && dp[i - 1][j - q] != INF) {
                    dp[i][j] = min(dp[i][j], dp[i - 1][j - q] + e[i - 1][q]);
                }
            }
        }
    }

    LD min_total_err = INF;
    for (int j = 0; j <= q_max; ++j) {
        min_total_err = min(min_total_err, dp[n][j]);
    }

    cout << (long long)floor(min_total_err * 100 + 1e-9) << endl;

    return 0;
}
import java.util.Scanner;
import java.util.Arrays;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int h = sc.nextInt();
        int qMax = sc.nextInt();

        double[][] e = new double[n][9];
        int[] bits = {2, 4, 8};

        for (int i = 0; i < n; i++) {
            double[] weights = new double[h];
            for (int j = 0; j < h; j++) {
                weights[j] = sc.nextDouble();
            }
            for (int q : bits) {
                double layerErr = 0;
                double p = Math.pow(2, q);
                for (int j = 0; j < h; j++) {
                    double wq = Math.floor(weights[j] * p);
                    double wr = wq / p;
                    layerErr += Math.abs(weights[j] - wr);
                }
                e[i][q] = layerErr;
            }
        }

        double INF = 1e18;
        double[][] dp = new double[n + 1][qMax + 1];
        for (int i = 0; i <= n; i++) {
            Arrays.fill(dp[i], INF);
        }
        dp[0][0] = 0;

        for (int i = 1; i <= n; i++) {
            for (int j = 0; j <= qMax; j++) {
                for (int q : bits) {
                    if (j >= q && dp[i - 1][j - q] != INF) {
                        dp[i][j] = Math.min(dp[i][j], dp[i - 1][j - q] + e[i - 1][q]);
                    }
                }
            }
        }

        double minTotalErr = INF;
        for (int j = 0; j <= qMax; j++) {
            minTotalErr = Math.min(minTotalErr, dp[n][j]);
        }

        System.out.println((long) Math.floor(minTotalErr * 100 + 1e-9));
    }
}
import math

def solve():
    line1 = input().split()
    if not line1: return
    n, h, q_max = map(int, line1)

    e = [[0.0] * 9 for _ in range(n)]
    bits = [2, 4, 8]

    for i in range(n):
        weights = list(map(float, input().split()))
        for q in bits:
            layer_err = 0.0
            p = 2**q
            for w in weights:
                wq = math.floor(w * p)
                wr = wq / p
                layer_err += abs(w - wr)
            e[i][q] = layer_err

    inf = 1e18
    dp = [[inf] * (q_max + 1) for _ in range(n + 1)]
    dp[0][0] = 0.0

    for i in range(1, n + 1):
        for j in range(q_max + 1):
            for q in bits:
                if j >= q and dp[i - 1][j - q] != inf:
                    val = dp[i - 1][j - q] + e[i - 1][q]
                    if val < dp[i][j]:
                        dp[i][j] = val

    min_err = inf
    for j in range(q_max + 1):
        if dp[n][j] < min_err:
            min_err = dp[n][j]

    print(int(math.floor(min_err * 100 + 1e-9)))

if __name__ == "__main__":
    solve()

算法及复杂度

  • 算法:动态规划(分组背包思想)。
  • 时间复杂度:。预处理每层误差耗时 ,动态规划状态转移耗时
  • 空间复杂度:。用于存储状态转移数组。