实现Masked Multi-Head Self-Attention

题意

给定输入张量 (形状 )和四个权重矩阵 (均为 ),以及 head 数量 ,要求实现完整的 Masked Multi-Head Self-Attention 计算,输出最终结果(保留两位小数)。

思路

这道题没有算法上的难度,纯粹是按步骤实现 Transformer 里 Decoder 用的 Masked Multi-Head Self-Attention。关键是把每一步的矩阵运算和维度变换搞对。

整个流程分七步,我们一步步理清楚:

第一步:线性投影。 用输入 分别乘以 得到 ,形状不变,仍然是

第二步:拆成多头。。把最后一维按 head 切开,从 变成 。具体来说,第 个 head 取每个 token 向量的第 段。

第三步:算注意力分数。 对每个 head 独立算 ,得到 的矩阵。

第四步:因果掩码。 这是 "Masked" 的关键——位置 只能看到 的位置。把 的分数设成一个很大的负数(比如 ),这样 softmax 之后这些位置的权重基本是 0。

第五步:数值稳定的 Softmax。 对每一行,先减去行最大值再取指数,避免溢出:

$$

第六步:加权求和。 用 softmax 权重对 做加权求和。

第七步:拼回去,最终投影。 把所有 head 的输出沿最后一维拼接回 ,再乘 得到最终输出。

复杂度

为 batch size, 为序列长度,

  • 线性投影:
  • 注意力分数:
  • 总体

代码

import math

def solve():
    line = input().strip()
    parts = line.split(';')
    num_heads = int(parts[0].strip())
    X = eval(parts[1].strip())
    W_Q = eval(parts[2].strip())
    W_K = eval(parts[3].strip())
    W_V = eval(parts[4].strip())
    W_O = eval(parts[5].strip())

    batch = len(X)
    seq = len(X[0])
    d_model = len(X[0][0])
    d_k = d_model // num_heads

    def matmul_3d_2d(A, B):
        res = []
        for b in range(len(A)):
            br = []
            for s in range(len(A[b])):
                row = []
                for j in range(len(B[0])):
                    val = sum(A[b][s][k] * B[k][j] for k in range(len(B)))
                    row.append(val)
                br.append(row)
            res.append(br)
        return res

    Q = matmul_3d_2d(X, W_Q)
    K = matmul_3d_2d(X, W_K)
    V = matmul_3d_2d(X, W_V)

    def reshape_to_heads(M):
        res = []
        for b in range(batch):
            heads = []
            for h in range(num_heads):
                head = []
                for s in range(seq):
                    head.append(M[b][s][h * d_k:(h + 1) * d_k])
                heads.append(head)
            res.append(heads)
        return res

    Q_h, K_h, V_h = reshape_to_heads(Q), reshape_to_heads(K), reshape_to_heads(V)

    scale = math.sqrt(d_k)
    NEG_INF = -1e9

    # 注意力分数 + 掩码 + softmax + 加权求和
    attn_out = []
    for b in range(batch):
        b_out = []
        for h in range(num_heads):
            # scores: [seq, seq]
            scores = []
            for i in range(seq):
                row = []
                for j in range(seq):
                    val = sum(Q_h[b][h][i][k] * K_h[b][h][j][k] for k in range(d_k))
                    row.append(val / scale if j <= i else NEG_INF)
                row_max = max(row)
                exps = [math.exp(x - row_max) for x in row]
                s = sum(exps)
                weights = [e / s for e in exps]
                # 直接算这一行的输出
                out_row = []
                for k in range(d_k):
                    out_row.append(sum(weights[j] * V_h[b][h][j][k] for j in range(seq)))
                scores.append(out_row)
            b_out.append(scores)
        attn_out.append(b_out)

    # 拼接 + W_O
    concat = []
    for b in range(batch):
        bc = []
        for s in range(seq):
            row = []
            for h in range(num_heads):
                row.extend(attn_out[b][h][s])
            bc.append(row)
        concat.append(bc)

    output = matmul_3d_2d(concat, W_O)

    def fmt(lst):
        if isinstance(lst[0], list):
            return '[' + ', '.join(fmt(x) for x in lst) + ']'
        return '[' + ', '.join(f"{v:.2f}" for v in lst) + ']'

    for b in range(batch):
        for s in range(seq):
            for d in range(d_model):
                output[b][s][d] = round(output[b][s][d], 2)
    print(fmt(output))

solve()