题目链接

对称INT8量化方案

题目描述

在嵌入式设备上运行神经网络推理时,需要将浮点参数压缩为低比特整数以减少开销。本题要求实现一种对称 INT8 量化方案:

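  1. 激活矩阵 $A$($m \times k$):按行量化(per-sample)。对第 $i$ 行,缩放因子 $s^A_i = \max_j |A_{ij}| / 127$。量化值 $Q^A_{ij} = \mathrm{clip}\big(\mathrm{round}(A_{ij} / s^A_i),\ -127,\ 127\big)$。
  2. 参数矩阵 $B$($k \times n$):按列量化(per-channel)。对第 $j$ 列,缩放因子 $s^B_j = \max_i |B_{ij}| / 127$。量化值 $Q^B_{ij} = \mathrm{clip}\big(\mathrm{round}(B_{ij} / s^B_j),\ -127,\ 127\big)$。
  3. 舍入规则:采用 Python round() 的银行家舍入(四舍六入五取偶)。
  4. 计算与还原:在整数域计算 $Q^A Q^B$,最后还原为浮点数 $C_{ij} = s^A_i \, s^B_j \sum_{l=1}^{k} Q^A_{il} Q^B_{lj}$。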

解题思路

  1. 数据读取:依次读取矩阵 $A$、$B$ 的维度及元素。
  2. 计算缩放因子
    • 对于 $A$ 的每一行 $i$,寻找绝对值的最大值并除以 $127$ 得到缩放因子 $s^A_i$。如果最大值为 $0$,则 $s^A_i = 0$。
    • 对于 $B$ 的每一列 $j$,寻找绝对值的最大值并除以 $127$ 得到缩放因子 $s^B_j$。如果最大值为 $0$,则 $s^B_j = 0$。
  3. 执行量化
    • 遍历矩阵元素进行量化,并将结果截断到 $[-127, 127]$。注意处理缩放因子 $s = 0$ 的情况(此时量化值为 $0$)。
    • 银行家舍入实现:这是本题的难点。需要确保在 C++ 和 Java 中实现与 Python round() 一致的逻辑(即:小数部分不为 $0.5$ 时正常四舍五入;恰好为 $0.5$ 时,舍入到最近的偶数)。Java 可直接使用 Math.rint。
  4. 矩阵乘法:在整数域(使用 int 或 long long)计算量化矩阵 $Q^A$ 与 $Q^B$ 的乘积。
  5. 还原与输出:将结果乘以对应行、列缩放因子之积 $s^A_i \cdot s^B_j$,并格式化输出,保留两位小数。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#include <iomanip>

using namespace std;

// Rounds x to the nearest integer; an exact .5 tie goes to the even
// neighbour (banker's rounding), mirroring Python 3's round().
long long banker_round(double x) {
    const double lower = floor(x);  // largest integer not above x
    const double frac = x - lower;  // distance from that integer

    if (frac > 0.5) {
        return (long long)(lower + 1);
    }

    const long long down = (long long)lower;
    if (frac < 0.5) {
        return down;
    }

    // Exact tie: stay on `down` when it is even, otherwise move up to the even neighbour.
    return (down % 2 == 0) ? down : down + 1;
}

// Quantizes one value with its row/column scale s: round-half-even of v / s
// (via banker_round), clamped to [-127, 127]. A zero scale marks an all-zero
// row/column, whose codes are all 0.
static int quantize_value(double v, double s) {
    if (s == 0) return 0;
    long long q = banker_round(v / s);
    return (int)max(-127LL, min(127LL, q));
}

// Reads A (m x k_a) and B (k_b x n), quantizes A per row and B per column to
// symmetric INT8, multiplies in the integer domain (long long accumulator),
// dequantizes with sa[i] * sb[j] and prints each output row with two decimals.
int main() {
    // Only C++ streams are used, so decoupling from stdio is safe and speeds up input.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int m = 0, k_a = 0;
    if (!(cin >> m >> k_a)) return 0;  // empty input: nothing to print

    // A: per-row scale sa[i] = max_j |a[i][j]| / 127.
    vector<vector<double>> a(m, vector<double>(k_a));
    vector<double> sa(m);
    vector<vector<int>> qa(m, vector<int>(k_a));
    for (int i = 0; i < m; ++i) {
        double max_abs = 0;
        for (int j = 0; j < k_a; ++j) {
            cin >> a[i][j];
            max_abs = max(max_abs, fabs(a[i][j]));
        }
        sa[i] = max_abs / 127.0;
        for (int j = 0; j < k_a; ++j) {
            qa[i][j] = quantize_value(a[i][j], sa[i]);
        }
    }

    int k_b = 0, n = 0;
    if (!(cin >> k_b >> n)) return 0;  // B missing: nothing to print
    if (k_b != k_a) {
        // The product below indexes qb[l] for every l < k_a; a mismatch would
        // read out of bounds (shorter B) or silently drop rows (taller B).
        cerr << "inner dimensions differ: A has " << k_a << " columns, B has " << k_b
             << " rows\n";
        return 1;
    }

    vector<vector<double>> b(k_b, vector<double>(n));
    for (int i = 0; i < k_b; ++i) {
        for (int j = 0; j < n; ++j) {
            cin >> b[i][j];
        }
    }

    // B: per-column scale sb[j] = max_i |b[i][j]| / 127.
    vector<double> sb(n);
    vector<vector<int>> qb(k_b, vector<int>(n));
    for (int j = 0; j < n; ++j) {
        double max_abs = 0;
        for (int i = 0; i < k_b; ++i) {
            max_abs = max(max_abs, fabs(b[i][j]));
        }
        sb[j] = max_abs / 127.0;
        for (int i = 0; i < k_b; ++i) {
            qb[i][j] = quantize_value(b[i][j], sb[j]);
        }
    }

    // Set the output format once; '\n' avoids flushing after every row.
    cout << fixed << setprecision(2);
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            long long sum = 0;
            for (int l = 0; l < k_a; ++l) {
                sum += (long long)qa[i][l] * qb[l][j];
            }
            double res = (double)sum * sa[i] * sb[j];
            cout << res << (j == n - 1 ? "" : " ");
        }
        cout << '\n';
    }

    return 0;
}
import java.util.*;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.math.RoundingMode;

/**
 * Symmetric INT8 quantized matrix multiplication.
 *
 * <p>Input (whitespace-separated tokens): {@code m k}, the m x k activation matrix A in
 * row-major order, then {@code kb n} and the kb x n weight matrix B. A is quantized per row and
 * B per column with scale {@code max|x| / 127}; codes are rounded half-to-even and clamped to
 * [-127, 127]. The integer product is accumulated in {@code long}, dequantized with
 * {@code sa[i] * sb[j]}, and printed with two decimals, one output row per line.
 */
public class Main {
    /** Largest magnitude of a symmetric INT8 code; also the scale divisor. */
    private static final int QMAX = 127;

    /**
     * Quantizes one value: rounds {@code x / scale} half-to-even and clamps it to
     * [-QMAX, QMAX].
     *
     * @param x the floating-point value
     * @param scale the row/column scale; 0 means the whole row/column is zero
     * @return the INT8 code, or 0 when {@code scale == 0}
     */
    private static int quantize(double x, double scale) {
        if (scale == 0) {
            return 0;
        }
        // Math.rint rounds exact .5 ties to the even neighbour, like Python's round().
        long rounded = (long) Math.rint(x / scale);
        return (int) Math.max(-QMAX, Math.min(QMAX, rounded));
    }

    /**
     * Reads a matrix of doubles in row-major order.
     *
     * @param in the token source
     * @param rows number of rows
     * @param cols number of columns
     * @return a new rows x cols array
     */
    private static double[][] readMatrix(Scanner in, int rows, int cols) {
        double[][] mat = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                mat[i][j] = in.nextDouble();
            }
        }
        return mat;
    }

    public static void main(String[] args) {
        // Locale.US guarantees that '.' is parsed as the decimal separator.
        Scanner sc = new Scanner(System.in).useLocale(Locale.US);

        // --- Matrix A: per-row (per-sample) quantization ---
        int m = sc.nextInt();
        int k = sc.nextInt();
        double[][] a = readMatrix(sc, m, k);
        double[] sa = new double[m];
        int[][] qa = new int[m][k];
        for (int i = 0; i < m; i++) {
            double maxAbs = 0;
            for (int j = 0; j < k; j++) {
                maxAbs = Math.max(maxAbs, Math.abs(a[i][j]));
            }
            sa[i] = maxAbs / QMAX;
            for (int j = 0; j < k; j++) {
                qa[i][j] = quantize(a[i][j], sa[i]);
            }
        }

        // --- Matrix B: per-column (per-channel) quantization ---
        int kb = sc.nextInt();
        int n = sc.nextInt();
        if (kb != k) {
            // The product indexes qb[l] for every l < k: a shorter B would run past the
            // array and a taller B would be silently truncated.
            throw new IllegalArgumentException(
                    "inner dimensions differ: A has " + k + " columns, B has " + kb + " rows");
        }
        double[][] b = readMatrix(sc, kb, n);
        double[] sb = new double[n];
        int[][] qb = new int[kb][n];
        for (int j = 0; j < n; j++) {
            double maxAbs = 0;
            for (int i = 0; i < kb; i++) {
                maxAbs = Math.max(maxAbs, Math.abs(b[i][j]));
            }
            sb[j] = maxAbs / QMAX;
            for (int i = 0; i < kb; i++) {
                qb[i][j] = quantize(b[i][j], sb[j]);
            }
        }

        // --- Integer-domain product, dequantization and output ---
        // Two decimals with ties to even; Locale.US symbols force '.' as the separator.
        DecimalFormat df = new DecimalFormat("0.00");
        df.setRoundingMode(RoundingMode.HALF_EVEN);
        df.setDecimalFormatSymbols(new DecimalFormatSymbols(Locale.US));

        StringBuilder line = new StringBuilder();
        for (int i = 0; i < m; i++) {
            line.setLength(0);
            for (int j = 0; j < n; j++) {
                long sum = 0;
                for (int l = 0; l < k; l++) {
                    sum += (long) qa[i][l] * qb[l][j];
                }
                // Evaluated left to right: (sum * sa[i]) * sb[j].
                double res = sum * sa[i] * sb[j];
                if (j > 0) {
                    line.append(' ');
                }
                line.append(df.format(res));
            }
            System.out.println(line);
        }
    }
}
def solve(data=None):
    """Run symmetric INT8 quantized matrix multiplication and print the result.

    Input is a sequence of whitespace-separated tokens: ``m k``, the ``m*k``
    entries of A (row-major), ``kb n``, then the ``kb*n`` entries of B
    (row-major). A is quantized per row and B per column with scale
    ``max|x| / 127``; codes use Python's round() (ties to even) and are
    clamped to [-127, 127]. The integer product is dequantized with
    ``sa[i] * sb[j]`` and each output row is printed with two decimals.

    Tokens (not lines) are consumed, so any line wrapping of the matrices is
    accepted, like the C++ ``cin`` and Java ``Scanner`` readers.

    Args:
        data: optional input text; when ``None`` the whole of stdin is read.

    Raises:
        ValueError: if the input is truncated or A's column count differs
            from B's row count.
    """
    import sys

    tokens = (sys.stdin.read() if data is None else data).split()
    pos = 0

    def take(count, cast):
        # Consume the next `count` tokens converted with `cast`.
        nonlocal pos
        if pos + count > len(tokens):
            raise ValueError("unexpected end of input")
        values = [cast(t) for t in tokens[pos:pos + count]]
        pos += count
        return values

    def scale_of(values):
        # max|x| / 127; default=0.0 keeps an empty row/column (k == 0) from crashing max().
        return max((abs(x) for x in values), default=0.0) / 127.0

    def quantize(x, s):
        # A zero scale means the whole row/column is zero, so every code is 0.
        if s == 0:
            return 0
        # Python 3's round() sends exact .5 ties to the even integer.
        return max(-127, min(127, round(x / s)))

    if not tokens:
        return
    m, k = take(2, int)
    flat_a = take(m * k, float)
    a = [flat_a[i * k:(i + 1) * k] for i in range(m)]

    if pos == len(tokens):
        return  # no B header at all: print nothing
    kb, n = take(2, int)
    if kb != k:
        raise ValueError(f"inner dimensions differ: A has {k} columns, B has {kb} rows")
    flat_b = take(kb * n, float)
    b = [flat_b[i * n:(i + 1) * n] for i in range(kb)]

    sa = [scale_of(row) for row in a]
    qa = [[quantize(x, s) for x in row] for row, s in zip(a, sa)]
    sb = [scale_of([row[j] for row in b]) for j in range(n)]
    qb = [[quantize(row[j], sb[j]) for j in range(n)] for row in b]

    for i in range(m):
        cells = []
        for j in range(n):
            acc = sum(qa[i][l] * qb[l][j] for l in range(k))
            # Evaluated left to right: (acc * sa[i]) * sb[j].
            cells.append(format(acc * sa[i] * sb[j], '.2f'))
        print(" ".join(cells))


if __name__ == "__main__":
    solve()

算法及复杂度

  • 算法:矩阵计算 + 对称量化处理
  • 时间复杂度:$O(mkn)$。量化过程耗时 $O(mk + kn)$,核心计算量在于整数域的矩阵乘法。
  • 空间复杂度:$O(mk + kn)$。需要存储矩阵元素、量化后的整数矩阵以及缩放因子。