题目链接
题目描述
在嵌入式设备上运行神经网络推理时,需要将浮点参数压缩为低比特整数以减少开销。本题要求实现一种对称 INT8 量化方案:
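- 激活矩阵 $A$($m \times k$):按行量化(per-sample)。对第 $i$ 行,缩放因子 $s_a[i] = \max_j |A_{ij}| / 127$。量化值 $q_a[i][j] = \operatorname{clip}\big(\operatorname{round}(A_{ij} / s_a[i]),\ -127,\ 127\big)$。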
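- 参数矩阵 $B$($k \times n$):按列量化(per-channel)。对第 $j$ 列,缩放因子 $s_b[j] = \max_i |B_{ij}| / 127$。量化值 $q_b[i][j] = \operatorname{clip}\big(\operatorname{round}(B_{ij} / s_b[j]),\ -127,\ 127\big)$。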
- 舍入规则:采用 Python `round()` 的银行家舍入(四舍六入五取偶)。
- 计算与还原:在整数域计算 $P_{ij} = \sum_{l=1}^{k} q_a[i][l] \cdot q_b[l][j]$,最后还原为浮点数 $C_{ij} = P_{ij} \cdot s_a[i] \cdot s_b[j]$。
解题思路
- 数据读取:依次读取矩阵 $A$ 和 $B$ 的维度及元素。
- 计算缩放因子:
  - 对于 $A$ 的每一行,寻找绝对值的最大值并除以 $127$ 得到 $s_a[i]$。如果最大值为 $0$,则 $s_a[i] = 0$。
  - 对于 $B$ 的每一列,寻找绝对值的最大值并除以 $127$ 得到 $s_b[j]$。如果最大值为 $0$,则 $s_b[j] = 0$。
- 执行量化:
  - 遍历矩阵元素进行量化。注意处理缩放因子 $s = 0$ 的情况(此时量化值为 $0$,以避免除以零)。
  - 银行家舍入实现:这是本题的难点。需要确保在 C++ 和 Java 中实现与 Python `round()` 一致的逻辑(即:小数部分不为 $0.5$ 时正常四舍五入;恰好为 $0.5$ 时,舍入到最近的偶数)。例如 $2.5 \to 2$,$3.5 \to 4$,$-2.5 \to -2$;Java 中的 `Math.rint` 即采用这一规则。
- 矩阵乘法:在整数域计算 $q_a$ 和 $q_b$ 的乘积(量化值用 `int` 存储,累加使用 `long long` 以防溢出)。
- 还原与输出:将结果乘以对应的 $s_a[i]$ 和 $s_b[j]$,并格式化输出保留两位小数。
代码
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#include <iomanip>
using namespace std;
// Round half to even ("banker's rounding"), the same tie rule as Python 3's
// round(): a fractional part below 0.5 rounds down, above 0.5 rounds up, and
// an exact .5 goes to whichever neighbouring integer is even.
long long banker_round(double x) {
    const double lower = std::floor(x);
    const double frac = x - lower;
    if (frac != 0.5) {
        // Not a tie: ordinary round-to-nearest.
        return static_cast<long long>(frac < 0.5 ? lower : lower + 1.0);
    }
    // Exact tie: stay on `down` when it is even, otherwise move up to the
    // even neighbour (for negatives, down % 2 is -1 when odd).
    const long long down = static_cast<long long>(lower);
    return (down % 2 == 0) ? down : down + 1;
}
// Quantize one value with a symmetric scale: round x / scale half-to-even
// (banker_round) and clamp to the INT8 range [-127, 127]. A zero scale means
// the whole row/column was zero, so the value is 0 and no division by zero
// happens.
static int quantize_value(double x, double scale) {
    if (scale == 0) return 0;
    const long long v = banker_round(x / scale);
    return static_cast<int>(std::max(-127LL, std::min(127LL, v)));
}

// Reads A (m x k_a) and B (k_b x n) from stdin as whitespace-separated
// tokens, quantizes A per row and B per column to INT8, multiplies them in
// the integer domain, and prints the dequantized product
// (sum * sa[i] * sb[j]) with two decimals, one output row per line.
// Returns 1 (with a message on stderr) if the inner dimensions differ.
int main() {
    int m = 0, k_a = 0;
    if (!(std::cin >> m >> k_a)) return 0;  // empty input: nothing to do

    // A: per-row (per-sample) scale sa[i] = max_j |A[i][j]| / 127.
    std::vector<std::vector<double>> a(m, std::vector<double>(k_a));
    std::vector<double> sa(m);
    std::vector<std::vector<int>> qa(m, std::vector<int>(k_a));
    for (int i = 0; i < m; ++i) {
        double max_val = 0;
        for (int j = 0; j < k_a; ++j) {
            std::cin >> a[i][j];
            max_val = std::max(max_val, std::abs(a[i][j]));
        }
        sa[i] = max_val / 127.0;
        for (int j = 0; j < k_a; ++j) {
            qa[i][j] = quantize_value(a[i][j], sa[i]);
        }
    }

    int k_b = 0, n = 0;
    std::cin >> k_b >> n;
    // The multiply below walks l over [0, k_a) and reads qb[l][j]; with
    // k_b < k_a that would read past the end of qb (undefined behaviour),
    // so reject mismatched inner dimensions up front.
    if (k_b != k_a) {
        std::cerr << "dimension mismatch: A is " << m << "x" << k_a
                  << " but B is " << k_b << "x" << n << '\n';
        return 1;
    }

    std::vector<std::vector<double>> b(k_b, std::vector<double>(n));
    for (auto& row : b) {
        for (double& v : row) std::cin >> v;
    }

    // B: per-column (per-channel) scale sb[j] = max_i |B[i][j]| / 127.
    std::vector<double> sb(n);
    std::vector<std::vector<int>> qb(k_b, std::vector<int>(n));
    for (int j = 0; j < n; ++j) {
        double max_val = 0;
        for (int i = 0; i < k_b; ++i) {
            max_val = std::max(max_val, std::abs(b[i][j]));
        }
        sb[j] = max_val / 127.0;
        for (int i = 0; i < k_b; ++i) {
            qb[i][j] = quantize_value(b[i][j], sb[j]);
        }
    }

    // Integer-domain product, then dequantize. Accumulate in long long: each
    // term can reach 127 * 127, so an int sum could overflow for large k.
    std::cout << std::fixed << std::setprecision(2);
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            long long acc = 0;
            for (int l = 0; l < k_a; ++l) {
                acc += static_cast<long long>(qa[i][l]) * qb[l][j];
            }
            if (j > 0) std::cout << ' ';
            std::cout << static_cast<double>(acc) * sa[i] * sb[j];
        }
        std::cout << '\n';
    }
    return 0;
}
import java.util.*;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.math.RoundingMode;
public class Main {
    /** Largest magnitude of a symmetric INT8 quantized value. */
    private static final int Q_MAX = 127;

    /**
     * Quantizes {@code x} with a symmetric scale: rounds {@code x / scale}
     * half-to-even via {@link Math#rint} (the same tie rule as Python's
     * {@code round()}) and clamps the result to [-127, 127]. A zero scale
     * means the whole row/column is zero, so 0 is returned instead of
     * dividing by zero.
     */
    private static int quantize(double x, double scale) {
        if (scale == 0) {
            return 0;
        }
        long rounded = (long) Math.rint(x / scale);
        return (int) Math.max(-Q_MAX, Math.min(Q_MAX, rounded));
    }

    /**
     * Reads A (m x k) and B (kb x n) from stdin, quantizes A per row and B
     * per column to INT8, multiplies them in the integer domain and prints
     * the dequantized product with two decimals (HALF_EVEN), one row per line.
     *
     * @throws IllegalArgumentException if the inner dimensions differ (kb != k)
     */
    public static void main(String[] args) {
        // Locale.US so that "." is accepted as the decimal separator.
        Scanner sc = new Scanner(System.in).useLocale(Locale.US);

        // --- Matrix A: per-row (per-sample) quantization ---
        int m = sc.nextInt();
        int k = sc.nextInt();
        double[][] a = new double[m][k];
        double[] sa = new double[m];
        int[][] qa = new int[m][k];
        for (int i = 0; i < m; i++) {
            double maxAbs = 0;
            for (int j = 0; j < k; j++) {
                a[i][j] = sc.nextDouble();
                maxAbs = Math.max(maxAbs, Math.abs(a[i][j]));
            }
            sa[i] = maxAbs / 127.0;
            for (int j = 0; j < k; j++) {
                qa[i][j] = quantize(a[i][j], sa[i]);
            }
        }

        // --- Matrix B: per-column (per-channel) quantization ---
        int kb = sc.nextInt();
        int n = sc.nextInt();
        // Without this check, kb < k fails later with an
        // ArrayIndexOutOfBoundsException, and kb > k silently leaves the
        // extra rows of B out of the product.
        if (kb != k) {
            throw new IllegalArgumentException(
                    "dimension mismatch: A is " + m + "x" + k + " but B is " + kb + "x" + n);
        }
        double[][] b = new double[kb][n];
        for (int i = 0; i < kb; i++) {
            for (int j = 0; j < n; j++) {
                b[i][j] = sc.nextDouble();
            }
        }
        double[] sb = new double[n];
        int[][] qb = new int[kb][n];
        for (int j = 0; j < n; j++) {
            double maxAbs = 0;
            for (int i = 0; i < kb; i++) {
                maxAbs = Math.max(maxAbs, Math.abs(b[i][j]));
            }
            sb[j] = maxAbs / 127.0;
            for (int i = 0; i < kb; i++) {
                qb[i][j] = quantize(b[i][j], sb[j]);
            }
        }

        // --- Integer matrix product, dequantization and output ---
        // HALF_EVEN formatting with US symbols so the decimal point is ".".
        DecimalFormat df = new DecimalFormat("0.00", new DecimalFormatSymbols(Locale.US));
        df.setRoundingMode(RoundingMode.HALF_EVEN);
        String newline = System.lineSeparator();
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                // long accumulator: each term can reach 127 * 127.
                long acc = 0;
                for (int l = 0; l < k; l++) {
                    acc += (long) qa[i][l] * qb[l][j];
                }
                if (j > 0) {
                    out.append(' ');
                }
                out.append(df.format(acc * sa[i] * sb[j]));
            }
            out.append(newline);
        }
        System.out.print(out);
        System.out.flush();
    }
}
def solve():
    """Read A and B from stdin, quantize them to symmetric INT8, print A @ B.

    Input is a whitespace-separated token stream: ``m k``, the ``m * k``
    entries of A (row-major), ``kb n``, then the ``kb * n`` entries of B
    (row-major). A is quantized per row and B per column with
    ``scale = max|x| / 127``; a zero scale quantizes every value to 0. The
    integer product is rescaled by ``sa[i] * sb[j]`` and each output row is
    printed as space-separated values with two decimals.

    Returns without output on empty input. Raises ValueError when the input
    is truncated or when the inner dimensions of A and B differ.
    """
    import sys

    # Read the whole input as one whitespace-separated token stream, the way
    # the C++ (cin >>) and Java (Scanner) versions do, so parsing does not
    # depend on how matrix entries are split across lines.
    stream = iter(sys.stdin.read().split())

    def take():
        # Next token; a truncated input is reported as ValueError rather
        # than leaking a bare StopIteration.
        token = next(stream, None)
        if token is None:
            raise ValueError("unexpected end of input")
        return token

    def scale_of(values):
        # Symmetric scale max|x| / 127; default=0.0 lets an empty row or
        # column yield scale 0 instead of raising inside max().
        return max((abs(x) for x in values), default=0.0) / 127.0

    def quantize(values, scale):
        if scale == 0:
            # All-zero row/column: quantize to 0 instead of dividing by zero.
            return [0] * len(values)
        # Python 3's round() rounds half to even, as the problem requires.
        return [max(-127, min(127, round(x / scale))) for x in values]

    # Matrix A (m x k): per-row (per-sample) quantization.
    header = next(stream, None)
    if header is None:
        return  # empty input
    m, k = int(header), int(take())
    a = [[float(take()) for _ in range(k)] for _ in range(m)]
    sa = [scale_of(row) for row in a]
    qa = [quantize(row, s) for row, s in zip(a, sa)]

    # Matrix B (kb x n): per-column (per-channel) quantization.
    header = next(stream, None)
    if header is None:
        return  # A given but no B: nothing to multiply
    kb, n = int(header), int(take())
    if kb != k:
        raise ValueError(f"dimension mismatch: A is {m}x{k} but B is {kb}x{n}")
    b = [[float(take()) for _ in range(n)] for _ in range(kb)]
    columns = [[row[j] for row in b] for j in range(n)]
    sb = [scale_of(col) for col in columns]
    qb = [quantize(col, s) for col, s in zip(columns, sb)]  # qb[j] is column j

    # Integer-domain dot products, then dequantize with both scales.
    for qa_row, s_a in zip(qa, sa):
        cells = []
        for qb_col, s_b in zip(qb, sb):
            acc = sum(x * w for x, w in zip(qa_row, qb_col))
            cells.append(format(acc * s_a * s_b, ".2f"))
        print(" ".join(cells))
# Run only when executed as a script, so importing this module does not
# block on stdin.
if __name__ == "__main__":
    solve()
算法及复杂度
- 算法:矩阵计算 + 对称量化处理
- 时间复杂度:$O(mkn)$。量化过程耗时 $O(mk + kn)$,核心计算量在于整数域的矩阵乘法。
- 空间复杂度:$O(mk + kn)$。需要存储矩阵元素、量化后的整数矩阵以及缩放因子。

京公网安备 11010502036488号