INT8 非对称量化下的全连接与误差评估

题目分析

给定输入向量 (长度 )和权重矩阵 ),需要:

  1. 分别进行 per-tensor INT8 非对称量化(范围 )。
  2. 在量化域计算全连接层的点积输出。
  3. 分别反量化后再做浮点全连接,与原始浮点全连接的结果计算 MSE,输出

思路

模拟量化/反量化流程

核心是严格按照公式实现三个步骤:

量化参数计算:对一组值 ,计算 。若 ,则 ,所有量化值为

量化,其中 round 使用银行家舍入(四舍六入五成双)。

反量化

关键细节:

  • 按整个向量做 per-tensor 量化,整个矩阵(而非逐行)做 per-tensor 量化。
  • 量化域点积:,输出整数。
  • MSE 的计算对象是浮点全连接输出 反量化全连接输出 (用反量化后的 做矩阵乘法),

整体只需按部就班地模拟,没有复杂的算法技巧。

代码

import sys
import math

def solve():
    data = sys.stdin.read().split()
    idx = 0
    n = int(data[idx]); idx += 1
    x = [float(data[idx + i]) for i in range(n)]; idx += n
    m = int(data[idx]); idx += 1
    n2 = int(data[idx]); idx += 1
    W = []
    for i in range(m):
        row = [float(data[idx + j]) for j in range(n)]
        idx += n
        W.append(row)

    def quantize(vals):
        mn, mx = min(vals), max(vals)
        if mn == mx:
            return [-128] * len(vals), 0.0, mn
        scale = (mx - mn) / 255.0
        q = [max(-128, min(127, int(round((v - mn) / scale)) - 128)) for v in vals]
        return q, scale, mn

    # per-tensor 量化
    x_q, x_scale, x_min = quantize(x)
    w_flat = [v for row in W for v in row]
    w_q_flat, w_scale, w_min = quantize(w_flat)
    W_q = [w_q_flat[i * n:(i + 1) * n] for i in range(m)]

    # 量化域点积
    y_quant = []
    for i in range(m):
        y_quant.append(sum(x_q[j] * W_q[i][j] for j in range(n)))

    # 反量化
    x_deq = [(x_q[j] + 128) * x_scale + x_min for j in range(n)]
    W_deq = [[(W_q[i][j] + 128) * w_scale + w_min for j in range(n)] for i in range(m)]

    # 浮点全连接 vs 反量化全连接
    y_float = [sum(x[j] * W[i][j] for j in range(n)) for i in range(m)]
    y_deq = [sum(x_deq[j] * W_deq[i][j] for j in range(n)) for i in range(m)]

    # MSE
    mse = sum((y_float[i] - y_deq[i]) ** 2 for i in range(m)) / m
    result = math.floor(mse * 100000 + 0.5)

    print(' '.join(map(str, y_quant)))
    print(result)

solve()

复杂度分析

  • 时间复杂度:,量化、点积和 MSE 计算均为矩阵规模的线性扫描。
  • 空间复杂度:,存储矩阵及其量化/反量化结果。