题目链接

实现Masked Multi-Head Self-Attention

题目描述

给定一个批次的序列输入 (形状为 [batch, seq, d_model])以及四个权重矩阵 (均为 d_model × d_model),你需要手动实现一个带因果掩码(Causal Mask)的多头自注意力机制。

核心步骤

  1. 线性投射: 计算 , ,
  2. 多头拆分: 将 的最后一维 d_model 拆分为 num_heads 个头,每个头的维度为 d_k = d_model / num_heads。形状从 [batch, seq, d_model] 变为 [batch, num_heads, seq, d_k]
  3. 计算注意力分数: 计算 scores = (Q @ K^T) / sqrt(d_k)。其中 @ 代表在最后两维进行矩阵乘法。
  4. 因果掩码: 对 scores 应用一个下三角掩码,将上三角部分(不含对角线)的值设置为一个极小的负数(如 -1e9),以防止模型看到未来的信息。
  5. Softmax: 对加了掩码的 scores 在最后一个维度上进行数值稳定的 Softmax 操作,得到注意力权重。
  6. 加权求和: 将 Softmax 得到的权重与 相乘 attention = weights @ V
  7. 多头合并与输出投射: 将多头的结果合并,形状从 [batch, num_heads, seq, d_k] 恢复为 [batch, seq, d_model],然后乘以输出权重矩阵

输出要求:结果保留两位小数,并以 Python 列表格式输出。

解题思路

本题是一次对 Transformer 模型核心组件的底层模拟实现。解题的关键在于将复杂的张量运算分解为基础的循环和数学运算,并严格遵循题目给出的每一步。由于输入格式为 Python 风格的字符串,Python 的 eval 函数在解析上具有天然优势,而 C++ 和 Java 则需要手动编写一个解析器。

通用步骤

  1. 输入解析:

    • Python: 使用 str.split(';') 分割参数,再用 eval() 将矩阵字符串直接转换为嵌套列表。
    • C++/Java: 需要手动编写一个解析器。一个可行的策略是,对于矩阵字符串,通过追踪方括号 [] 的嵌套深度来找到顶层逗号 ,,从而分割出子矩阵/向量,再进行递归解析。
  2. 辅助函数:

    • 需要实现一个 matmul 函数,用于计算两个二维矩阵(List<List>vector<vector>)的乘法。
    • 所有计算过程应使用浮点数(double)以保证精度。
  3. 模拟注意力计算:

    • Step 1 (线性投射): 遍历 batch 维度,对每个 X[b] 和权重矩阵做 matmul
    • Step 2 (多头拆分): 实现一个 reshape 函数。通过循环将 [batch, seq, d_model] 形状的张量转换为 [batch, num_heads, seq, d_k]
    • Step 3 (计算分数): 遍历 batchnum_heads 维度。对每个头的 矩阵和 矩阵的转置做 matmul,然后将结果除以
    • Step 4 (掩码): 对上一步得到的分数矩阵,用一个嵌套循环遍历其最后两维,将上三角部分的值设为 -1e9
    • Step 5 (Softmax): 实现一个数值稳定的 softmax 函数。对于分数矩阵的每一行,先减去该行的最大值,再计算 exp,最后归一化。
    • Step 6 (加权求和): 再次调用 matmul,将 Softmax 权重矩阵与多头拆分后的 相乘。
    • Step 7 (合并与输出): 实现一个“逆向”的 reshape 操作,将 [batch, num_heads, seq, d_k] 合并回 [batch, seq, d_model]。最后,再次调用 matmul 将结果与 相乘。
  4. 格式化输出:

    • 遍历最终结果张量,将每个浮点数四舍五入到两位小数,并按题目要求的 Python 列表格式打印。

代码

import math
import sys

# 为了简洁,直接在主函数内实现所有逻辑

def main():
    # 1. 解析输入
    line = input()
    parts = line.split(';')
    
    num_heads = int(parts[0])
    X = eval(parts[1])
    W_Q = eval(parts[2])
    W_K = eval(parts[3])
    W_V = eval(parts[4])
    W_O = eval(parts[5])

    batch, seq, d_model = len(X), len(X[0]), len(X[0][0])
    d_k = d_model // num_heads

    # 辅助函数:矩阵乘法
    def matmul(A, B):
        rows_A, cols_A = len(A), len(A[0])
        rows_B, cols_B = len(B), len(B[0])
        C = [[0.0 for _ in range(cols_B)] for _ in range(rows_A)]
        for i in range(rows_A):
            for j in range(cols_B):
                for k in range(cols_A):
                    C[i][j] += A[i][k] * B[k][j]
        return C

    # Step 1: 线性投射
    Q = [matmul(X[b], W_Q) for b in range(batch)]
    K = [matmul(X[b], W_K) for b in range(batch)]
    V = [matmul(X[b], W_V) for b in range(batch)]

    # Step 2: 多头拆分 [b, s, d_m] -> [b, h, s, d_k]
    Q_heads = [[[[0.0 for _ in range(d_k)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
    K_heads = [[[[0.0 for _ in range(d_k)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
    V_heads = [[[[0.0 for _ in range(d_k)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]

    for b in range(batch):
        for s in range(seq):
            for h in range(num_heads):
                for i in range(d_k):
                    idx = h * d_k + i
                    Q_heads[b][h][s][i] = Q[b][s][idx]
                    K_heads[b][h][s][i] = K[b][s][idx]
                    V_heads[b][h][s][i] = V[b][s][idx]

    # Step 3: 计算注意力分数
    scores = [[[[0.0 for _ in range(seq)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
    scale = math.sqrt(d_k)

    for b in range(batch):
        for h in range(num_heads):
            # Q[b][h] @ K[b][h].T
            q_slice = Q_heads[b][h]
            k_slice = K_heads[b][h]
            for i in range(seq):
                for j in range(seq):
                    dot_product = sum(q_slice[i][k] * k_slice[j][k] for k in range(d_k))
                    scores[b][h][i][j] = dot_product / scale

    # Step 4: 因果掩码
    for b in range(batch):
        for h in range(num_heads):
            for i in range(seq):
                for j in range(i + 1, seq):
                    scores[b][h][i][j] = -1e9

    # Step 5: Softmax
    weights = [[[[0.0 for _ in range(seq)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
    for b in range(batch):
        for h in range(num_heads):
            for i in range(seq):
                row = scores[b][h][i]
                max_val = max(row)
                exps = [math.exp(v - max_val) for v in row]
                sum_exps = sum(exps)
                weights[b][h][i] = [e / sum_exps for e in exps]
    
    # Step 6: 加权求和
    attention = [[[[0.0 for _ in range(d_k)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
    for b in range(batch):
        for h in range(num_heads):
            attention[b][h] = matmul(weights[b][h], V_heads[b][h])

    # Step 7: 多头合并与输出投射
    # a. 合并 [b, h, s, d_k] -> [b, s, d_m]
    concat_attention = [[[0.0 for _ in range(d_model)] for _ in range(seq)] for _ in range(batch)]
    for b in range(batch):
        for s in range(seq):
            for h in range(num_heads):
                for i in range(d_k):
                    concat_attention[b][s][h * d_k + i] = attention[b][h][s][i]
    
    # b. 输出投射
    final_output = [matmul(concat_attention[b], W_O) for b in range(batch)]

    # 格式化输出
    rounded_output = [[[round(val, 2) for val in vec] for vec in mat] for mat in final_output]
    
    # Python list to string
    def to_str(l):
        if isinstance(l, list):
            return "[" + ", ".join(to_str(i) for i in l) + "]"
        return f"{l:.2f}"

    print(to_str(rounded_output))


if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:Transformer 自注意力机制模拟
  • 时间复杂度:
    • 线性投射为
    • 注意力分数计算为
    • 后续步骤复杂度较低。瓶颈通常在注意力分数的计算上。
  • 空间复杂度:,主要由存储注意力分数矩阵 scores 决定。