题目链接

INT8 非对称量化下的全连接与误差评估

题目描述

在移动端或边缘设备上,浮点运算成本较高。常见做法是将输入向量和全连接层权重做 INT8 非对称量化(按张量整体 per-tensor),用整数在量化域直接做点积,最后用反量化结果评估与原始浮点结果的误差。

任务:

  1. 对输入向量 和权重矩阵 分别做 INT8 非对称量化(范围 ,不加偏置),输出量化域的 个整数点积结果。
  2. 将量化后的 分别反量化为 ,计算二者在浮点域的全连接输出,与原始 的浮点输出做均方误差 MSE,并输出 的整数。

量化/反量化细节 (per-tensor):

    • ,则 ,量化结果全为 ;反量化直接取
  • 量化: ,其中 为就近取偶。
  • 反量化:
  • MSE 四舍五入采用 half-up(即对 做 “ 下取整”)。

输入描述

  • 第一行: (输入向量维度)
  • 第二行: 个浮点数(输入向量
  • 第三行: (权重矩阵维度)
  • 接着 行: 每行 个浮点数(权重矩阵

输出描述

  • 第一行: 个整数(使用 计算的量化域全连接输出)
  • 第二行: 1 个整数(

解题思路

本题的核心是精确模拟一个简化的神经网络量化与反量化过程。我们需要严格按照题目定义的公式和步骤进行计算,没有复杂的算法思想,但需要注意实现的细节,特别是浮点数处理和取整规则。

整个流程可以分解为以下几个主要步骤:

  1. 数据读取与准备:

    • 读取输入向量 (维度为 ) 和权重矩阵 (维度为 )。
    • 由于权重矩阵 的量化是 per-tensor(按张量整体)的,我们需要遍历整个 矩阵,找出所有元素中的最大值 和最小值 。对向量 也做同样处理得到
  2. 量化 (Quantization):

    • 计算缩放因子 : 对 分别计算 。需要处理 的特殊情况,此时
    • 执行量化: 对每个浮点值 ,应用量化公式:
      • 这里的 指的是“就近取偶”规则(例如,2.5 取整为 2,3.5 取整为 4)。
      • 函数确保结果落在 INT8 的范围 内。
      • 的特例下,所有量化结果 均为
    • 经过此步骤,我们得到量化后的整数向量 和整数矩阵
  3. 量化域计算:

    • 使用量化后的 计算 个点积。对于 的每一行 ,计算
    • 这是本题的第一部分输出。
  4. 反量化 (De-quantization) 与误差评估:

    • 执行反量化: 使用之前计算的 ,将 转换回浮点域,得到 。公式为:
    • 计算两种浮点结果:
      1. 原始结果: 使用原始的 计算浮点点积
      2. 反量化结果: 使用 计算浮点点积
    • 计算均方误差 (MSE):
    • 处理最终输出: 对 进行处理并输出:
      • 这里的 是标准的“四舍五入”规则 (例如,2.5 取整为 3)。

代码

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <cmath>
#include <limits>

using namespace std;

// 就近取偶
long long round_half_even(double val) {
    return static_cast<long long>(rint(val));
}

// 四舍五入
long long round_half_up(double val) {
    return static_cast<long long>(floor(val + 0.5));
}

// clamp
int clamp(long long val, int min_val, int max_val) {
    return max(min_val, min((int)val, max_val));
}

int main() {
    int n;
    cin >> n;
    vector<double> x(n);
    double x_min = numeric_limits<double>::max();
    double x_max = numeric_limits<double>::lowest();
    for (int i = 0; i < n; ++i) {
        cin >> x[i];
        x_min = min(x_min, x[i]);
        x_max = max(x_max, x[i]);
    }

    int m;
    cin >> m >> n;
    vector<vector<double>> W(m, vector<double>(n));
    double W_min = numeric_limits<double>::max();
    double W_max = numeric_limits<double>::lowest();
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            cin >> W[i][j];
            W_min = min(W_min, W[i][j]);
            W_max = max(W_max, W[i][j]);
        }
    }

    // 量化
    double x_scale = (x_max == x_min) ? 0.0 : (x_max - x_min) / 255.0;
    vector<int> x_quant(n);
    for (int i = 0; i < n; ++i) {
        if (x_scale == 0.0) {
            x_quant[i] = -128;
        } else {
            x_quant[i] = clamp(round_half_even((x[i] - x_min) / x_scale) - 128, -128, 127);
        }
    }

    double W_scale = (W_max == W_min) ? 0.0 : (W_max - W_min) / 255.0;
    vector<vector<int>> W_quant(m, vector<int>(n));
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            if (W_scale == 0.0) {
                W_quant[i][j] = -128;
            } else {
                W_quant[i][j] = clamp(round_half_even((W[i][j] - W_min) / W_scale) - 128, -128, 127);
            }
        }
    }

    // 量化域计算
    vector<long long> y_quant(m);
    for (int i = 0; i < m; ++i) {
        long long dot_product = 0;
        for (int j = 0; j < n; ++j) {
            dot_product += (long long)x_quant[j] * W_quant[i][j];
        }
        y_quant[i] = dot_product;
    }

    for (int i = 0; i < m; ++i) {
        cout << y_quant[i] << (i == m - 1 ? "" : " ");
    }
    cout << endl;

    // 反量化
    vector<double> x_dequant(n);
    for (int i = 0; i < n; ++i) {
        x_dequant[i] = (x_quant[i] + 128) * x_scale + x_min;
    }

    vector<vector<double>> W_dequant(m, vector<double>(n));
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            W_dequant[i][j] = (W_quant[i][j] + 128) * W_scale + W_min;
        }
    }

    // 误差评估
    vector<double> y_float(m);
    for (int i = 0; i < m; ++i) {
        double dot_product = 0;
        for (int j = 0; j < n; ++j) {
            dot_product += x[j] * W[i][j];
        }
        y_float[i] = dot_product;
    }

    vector<double> y_dequant(m);
    for (int i = 0; i < m; ++i) {
        double dot_product = 0;
        for (int j = 0; j < n; ++j) {
            dot_product += x_dequant[j] * W_dequant[i][j];
        }
        y_dequant[i] = dot_product;
    }

    double mse = 0;
    for (int i = 0; i < m; ++i) {
        mse += pow(y_float[i] - y_dequant[i], 2);
    }
    mse /= m;

    cout << round_half_up(mse * 100000) << endl;

    return 0;
}
import java.util.Scanner;
import java.util.Arrays;
import java.lang.Math;

public class Main {
    // 就近取偶
    private static long round_half_even(double val) {
        return (long) Math.rint(val);
    }

    // 四舍五入
    private static long round_half_up(double val) {
        return (long) Math.floor(val + 0.5);
    }

    // clamp
    private static int clamp(long val, int min_val, int max_val) {
        return Math.max(min_val, Math.min((int)val, max_val));
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int n = sc.nextInt();
        double[] x = new double[n];
        double x_min = Double.MAX_VALUE;
        double x_max = -Double.MAX_VALUE;
        for (int i = 0; i < n; i++) {
            x[i] = sc.nextDouble();
            x_min = Math.min(x_min, x[i]);
            x_max = Math.max(x_max, x[i]);
        }

        int m = sc.nextInt();
        sc.nextInt(); // consume n
        double[][] W = new double[m][n];
        double W_min = Double.MAX_VALUE;
        double W_max = -Double.MAX_VALUE;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                W[i][j] = sc.nextDouble();
                W_min = Math.min(W_min, W[i][j]);
                W_max = Math.max(W_max, W[i][j]);
            }
        }

        // 量化
        double x_scale = (x_max == x_min) ? 0.0 : (x_max - x_min) / 255.0;
        int[] x_quant = new int[n];
        for (int i = 0; i < n; i++) {
            if (x_scale == 0.0) {
                x_quant[i] = -128;
            } else {
                x_quant[i] = clamp(round_half_even((x[i] - x_min) / x_scale) - 128, -128, 127);
            }
        }

        double W_scale = (W_max == W_min) ? 0.0 : (W_max - W_min) / 255.0;
        int[][] W_quant = new int[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (W_scale == 0.0) {
                    W_quant[i][j] = -128;
                } else {
                    W_quant[i][j] = clamp(round_half_even((W[i][j] - W_min) / W_scale) - 128, -128, 127);
                }
            }
        }

        // 量化域计算
        long[] y_quant = new long[m];
        for (int i = 0; i < m; i++) {
            long dot_product = 0;
            for (int j = 0; j < n; j++) {
                dot_product += (long) x_quant[j] * W_quant[i][j];
            }
            y_quant[i] = dot_product;
        }

        for (int i = 0; i < m; i++) {
            System.out.print(y_quant[i] + (i == m - 1 ? "" : " "));
        }
        System.out.println();

        // 反量化
        double[] x_dequant = new double[n];
        for (int i = 0; i < n; i++) {
            x_dequant[i] = (x_quant[i] + 128) * x_scale + x_min;
        }

        double[][] W_dequant = new double[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                W_dequant[i][j] = (W_quant[i][j] + 128) * W_scale + W_min;
            }
        }

        // 误差评估
        double[] y_float = new double[m];
        for (int i = 0; i < m; i++) {
            double dot_product = 0;
            for (int j = 0; j < n; j++) {
                dot_product += x[j] * W[i][j];
            }
            y_float[i] = dot_product;
        }

        double[] y_dequant = new double[m];
        for (int i = 0; i < m; i++) {
            double dot_product = 0;
            for (int j = 0; j < n; j++) {
                dot_product += x_dequant[j] * W_dequant[i][j];
            }
            y_dequant[i] = dot_product;
        }

        double mse = 0;
        for (int i = 0; i < m; i++) {
            mse += Math.pow(y_float[i] - y_dequant[i], 2);
        }
        mse /= m;
        
        System.out.println(round_half_up(mse * 100000));
    }
}
import math

def round_half_up(n):
    return math.floor(n + 0.5)

def clamp(val, min_val, max_val):
    return max(min_val, min(val, max_val))

def solve():
    n = int(input())
    x = list(map(float, input().split()))
    x_min, x_max = min(x), max(x)

    m, _ = map(int, input().split())
    W = []
    W_flat = []
    for _ in range(m):
        row = list(map(float, input().split()))
        W.append(row)
        W_flat.extend(row)
    
    W_min = min(W_flat) if W_flat else 0
    W_max = max(W_flat) if W_flat else 0

    # 量化
    x_scale = 0.0 if x_max == x_min else (x_max - x_min) / 255.0
    x_quant = []
    for val in x:
        if x_scale == 0.0:
            x_quant.append(-128)
        else:
            q_val = round((val - x_min) / x_scale) - 128
            x_quant.append(clamp(q_val, -128, 127))

    W_scale = 0.0 if W_max == W_min else (W_max - W_min) / 255.0
    W_quant = []
    for i in range(m):
        row_quant = []
        for j in range(n):
            if W_scale == 0.0:
                row_quant.append(-128)
            else:
                q_val = round((W[i][j] - W_min) / W_scale) - 128
                row_quant.append(clamp(q_val, -128, 127))
        W_quant.append(row_quant)

    # 量化域计算
    y_quant = []
    for i in range(m):
        dot_product = sum(x_quant[j] * W_quant[i][j] for j in range(n))
        y_quant.append(dot_product)
        
    print(*y_quant)

    # 反量化
    x_dequant = [(q + 128) * x_scale + x_min for q in x_quant]
    W_dequant = [[(q + 128) * W_scale + W_min for q in row] for row in W_quant]

    # 误差评估
    y_float = [sum(x[j] * W[i][j] for j in range(n)) for i in range(m)]
    y_dequant = [sum(x_dequant[j] * W_dequant[i][j] for j in range(n)) for i in range(m)]

    mse = sum((y_float[i] - y_dequant[i]) ** 2 for i in range(m)) / m
    
    print(int(round_half_up(mse * 100000)))

solve()

算法及复杂度

  • 算法: 数值模拟
  • 时间复杂度: - 主要开销在于遍历权重矩阵 以寻找最大/最小值、进行量化/反量化以及计算 次长度为 的点积。
  • 空间复杂度: - 需要存储输入的向量 和矩阵 ,以及它们的量化和反量化版本。