题目链接
题目描述
在移动端或边缘设备上,浮点运算成本较高。常见做法是将输入向量和全连接层权重做 INT8 非对称量化(按张量整体 per-tensor),用整数在量化域直接做点积,最后用反量化结果评估与原始浮点结果的误差。
任务:
- 对输入向量
和权重矩阵
分别做 INT8 非对称量化(范围
,不加偏置),输出量化域的
个整数点积结果。
- 将量化后的
与
分别反量化为
、
,计算二者在浮点域的全连接输出,与原始
、
的浮点输出做均方误差 MSE,并输出
的整数。
量化/反量化细节 (per-tensor):
- 若
,则
,量化结果全为
;反量化直接取
。
- 若
- 量化:
,其中
为就近取偶。
- 反量化:
- MSE 四舍五入采用
half-up(即对做 “
下取整”)。
输入描述
- 第一行:
(输入向量维度)
- 第二行:
个浮点数(输入向量
)
- 第三行:
(权重矩阵维度)
- 接着
行: 每行
个浮点数(权重矩阵
)
输出描述
- 第一行:
个整数(使用
与
计算的量化域全连接输出)
- 第二行: 1 个整数(
)
解题思路
本题的核心是精确模拟一个简化的神经网络量化与反量化过程。我们需要严格按照题目定义的公式和步骤进行计算,没有复杂的算法思想,但需要注意实现的细节,特别是浮点数处理和取整规则。
整个流程可以分解为以下几个主要步骤:
-
数据读取与准备:
- 读取输入向量
(维度为
) 和权重矩阵
(维度为
)。
- 由于权重矩阵
的量化是
per-tensor(按张量整体)的,我们需要遍历整个矩阵,找出所有元素中的最大值
和最小值
。对向量
也做同样处理得到
和
。
- 读取输入向量
-
量化 (Quantization):
- 计算缩放因子
: 对
和
分别计算
。需要处理
的特殊情况,此时
。
- 执行量化: 对每个浮点值
,应用量化公式:
- 这里的
指的是“就近取偶”规则(例如,2.5 取整为 2,3.5 取整为 4)。
函数确保结果落在 INT8 的范围
内。
- 在
的特例下,所有量化结果
均为
。
- 这里的
- 经过此步骤,我们得到量化后的整数向量
和整数矩阵
。
- 计算缩放因子
-
量化域计算:
- 使用量化后的
和
计算
个点积。对于
的每一行
,计算
。
- 这是本题的第一部分输出。
- 使用量化后的
-
反量化 (De-quantization) 与误差评估:
- 执行反量化: 使用之前计算的
和
,将
和
转换回浮点域,得到
和
。公式为:
- 计算两种浮点结果:
- 原始结果: 使用原始的
和
计算浮点点积
。
- 反量化结果: 使用
和
计算浮点点积
。
- 原始结果: 使用原始的
- 计算均方误差 (MSE):
- 处理最终输出: 对
进行处理并输出:
。
- 这里的
是标准的“四舍五入”规则 (例如,2.5 取整为 3)。
- 这里的
- 执行反量化: 使用之前计算的
代码
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <cmath>
#include <limits>
using namespace std;
// 就近取偶
long long round_half_even(double val) {
return static_cast<long long>(rint(val));
}
// 四舍五入
long long round_half_up(double val) {
return static_cast<long long>(floor(val + 0.5));
}
// clamp
int clamp(long long val, int min_val, int max_val) {
return max(min_val, min((int)val, max_val));
}
int main() {
int n;
cin >> n;
vector<double> x(n);
double x_min = numeric_limits<double>::max();
double x_max = numeric_limits<double>::lowest();
for (int i = 0; i < n; ++i) {
cin >> x[i];
x_min = min(x_min, x[i]);
x_max = max(x_max, x[i]);
}
int m;
cin >> m >> n;
vector<vector<double>> W(m, vector<double>(n));
double W_min = numeric_limits<double>::max();
double W_max = numeric_limits<double>::lowest();
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
cin >> W[i][j];
W_min = min(W_min, W[i][j]);
W_max = max(W_max, W[i][j]);
}
}
// 量化
double x_scale = (x_max == x_min) ? 0.0 : (x_max - x_min) / 255.0;
vector<int> x_quant(n);
for (int i = 0; i < n; ++i) {
if (x_scale == 0.0) {
x_quant[i] = -128;
} else {
x_quant[i] = clamp(round_half_even((x[i] - x_min) / x_scale) - 128, -128, 127);
}
}
double W_scale = (W_max == W_min) ? 0.0 : (W_max - W_min) / 255.0;
vector<vector<int>> W_quant(m, vector<int>(n));
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
if (W_scale == 0.0) {
W_quant[i][j] = -128;
} else {
W_quant[i][j] = clamp(round_half_even((W[i][j] - W_min) / W_scale) - 128, -128, 127);
}
}
}
// 量化域计算
vector<long long> y_quant(m);
for (int i = 0; i < m; ++i) {
long long dot_product = 0;
for (int j = 0; j < n; ++j) {
dot_product += (long long)x_quant[j] * W_quant[i][j];
}
y_quant[i] = dot_product;
}
for (int i = 0; i < m; ++i) {
cout << y_quant[i] << (i == m - 1 ? "" : " ");
}
cout << endl;
// 反量化
vector<double> x_dequant(n);
for (int i = 0; i < n; ++i) {
x_dequant[i] = (x_quant[i] + 128) * x_scale + x_min;
}
vector<vector<double>> W_dequant(m, vector<double>(n));
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
W_dequant[i][j] = (W_quant[i][j] + 128) * W_scale + W_min;
}
}
// 误差评估
vector<double> y_float(m);
for (int i = 0; i < m; ++i) {
double dot_product = 0;
for (int j = 0; j < n; ++j) {
dot_product += x[j] * W[i][j];
}
y_float[i] = dot_product;
}
vector<double> y_dequant(m);
for (int i = 0; i < m; ++i) {
double dot_product = 0;
for (int j = 0; j < n; ++j) {
dot_product += x_dequant[j] * W_dequant[i][j];
}
y_dequant[i] = dot_product;
}
double mse = 0;
for (int i = 0; i < m; ++i) {
mse += pow(y_float[i] - y_dequant[i], 2);
}
mse /= m;
cout << round_half_up(mse * 100000) << endl;
return 0;
}
import java.util.Scanner;
import java.util.Arrays;
import java.lang.Math;
public class Main {
// 就近取偶
private static long round_half_even(double val) {
return (long) Math.rint(val);
}
// 四舍五入
private static long round_half_up(double val) {
return (long) Math.floor(val + 0.5);
}
// clamp
private static int clamp(long val, int min_val, int max_val) {
return Math.max(min_val, Math.min((int)val, max_val));
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
double[] x = new double[n];
double x_min = Double.MAX_VALUE;
double x_max = -Double.MAX_VALUE;
for (int i = 0; i < n; i++) {
x[i] = sc.nextDouble();
x_min = Math.min(x_min, x[i]);
x_max = Math.max(x_max, x[i]);
}
int m = sc.nextInt();
sc.nextInt(); // consume n
double[][] W = new double[m][n];
double W_min = Double.MAX_VALUE;
double W_max = -Double.MAX_VALUE;
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
W[i][j] = sc.nextDouble();
W_min = Math.min(W_min, W[i][j]);
W_max = Math.max(W_max, W[i][j]);
}
}
// 量化
double x_scale = (x_max == x_min) ? 0.0 : (x_max - x_min) / 255.0;
int[] x_quant = new int[n];
for (int i = 0; i < n; i++) {
if (x_scale == 0.0) {
x_quant[i] = -128;
} else {
x_quant[i] = clamp(round_half_even((x[i] - x_min) / x_scale) - 128, -128, 127);
}
}
double W_scale = (W_max == W_min) ? 0.0 : (W_max - W_min) / 255.0;
int[][] W_quant = new int[m][n];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
if (W_scale == 0.0) {
W_quant[i][j] = -128;
} else {
W_quant[i][j] = clamp(round_half_even((W[i][j] - W_min) / W_scale) - 128, -128, 127);
}
}
}
// 量化域计算
long[] y_quant = new long[m];
for (int i = 0; i < m; i++) {
long dot_product = 0;
for (int j = 0; j < n; j++) {
dot_product += (long) x_quant[j] * W_quant[i][j];
}
y_quant[i] = dot_product;
}
for (int i = 0; i < m; i++) {
System.out.print(y_quant[i] + (i == m - 1 ? "" : " "));
}
System.out.println();
// 反量化
double[] x_dequant = new double[n];
for (int i = 0; i < n; i++) {
x_dequant[i] = (x_quant[i] + 128) * x_scale + x_min;
}
double[][] W_dequant = new double[m][n];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
W_dequant[i][j] = (W_quant[i][j] + 128) * W_scale + W_min;
}
}
// 误差评估
double[] y_float = new double[m];
for (int i = 0; i < m; i++) {
double dot_product = 0;
for (int j = 0; j < n; j++) {
dot_product += x[j] * W[i][j];
}
y_float[i] = dot_product;
}
double[] y_dequant = new double[m];
for (int i = 0; i < m; i++) {
double dot_product = 0;
for (int j = 0; j < n; j++) {
dot_product += x_dequant[j] * W_dequant[i][j];
}
y_dequant[i] = dot_product;
}
double mse = 0;
for (int i = 0; i < m; i++) {
mse += Math.pow(y_float[i] - y_dequant[i], 2);
}
mse /= m;
System.out.println(round_half_up(mse * 100000));
}
}
import math
def round_half_up(n):
return math.floor(n + 0.5)
def clamp(val, min_val, max_val):
return max(min_val, min(val, max_val))
def solve():
n = int(input())
x = list(map(float, input().split()))
x_min, x_max = min(x), max(x)
m, _ = map(int, input().split())
W = []
W_flat = []
for _ in range(m):
row = list(map(float, input().split()))
W.append(row)
W_flat.extend(row)
W_min = min(W_flat) if W_flat else 0
W_max = max(W_flat) if W_flat else 0
# 量化
x_scale = 0.0 if x_max == x_min else (x_max - x_min) / 255.0
x_quant = []
for val in x:
if x_scale == 0.0:
x_quant.append(-128)
else:
q_val = round((val - x_min) / x_scale) - 128
x_quant.append(clamp(q_val, -128, 127))
W_scale = 0.0 if W_max == W_min else (W_max - W_min) / 255.0
W_quant = []
for i in range(m):
row_quant = []
for j in range(n):
if W_scale == 0.0:
row_quant.append(-128)
else:
q_val = round((W[i][j] - W_min) / W_scale) - 128
row_quant.append(clamp(q_val, -128, 127))
W_quant.append(row_quant)
# 量化域计算
y_quant = []
for i in range(m):
dot_product = sum(x_quant[j] * W_quant[i][j] for j in range(n))
y_quant.append(dot_product)
print(*y_quant)
# 反量化
x_dequant = [(q + 128) * x_scale + x_min for q in x_quant]
W_dequant = [[(q + 128) * W_scale + W_min for q in row] for row in W_quant]
# 误差评估
y_float = [sum(x[j] * W[i][j] for j in range(n)) for i in range(m)]
y_dequant = [sum(x_dequant[j] * W_dequant[i][j] for j in range(n)) for i in range(m)]
mse = sum((y_float[i] - y_dequant[i]) ** 2 for i in range(m)) / m
print(int(round_half_up(mse * 100000)))
solve()
算法及复杂度
- 算法: 数值模拟
- 时间复杂度:
- 主要开销在于遍历权重矩阵
以寻找最大/最小值、进行量化/反量化以及计算
次长度为
的点积。
- 空间复杂度:
- 需要存储输入的向量
和矩阵
,以及它们的量化和反量化版本。

京公网安备 11010502036488号