实现Masked Multi-Head Self-Attention
题意
给定输入张量 (形状
)和四个权重矩阵
(均为
),以及 head 数量
,要求实现完整的 Masked Multi-Head Self-Attention 计算,输出最终结果(保留两位小数)。
思路
这道题没有算法上的难度,纯粹是按步骤实现 Transformer 里 Decoder 用的 Masked Multi-Head Self-Attention。关键是把每一步的矩阵运算和维度变换搞对。
整个流程分七步,我们一步步理清楚:
第一步:线性投影。 用输入 分别乘以
得到
,形状不变,仍然是
。
第二步:拆成多头。 设 。把最后一维按 head 切开,从
变成
。具体来说,第
个 head 取每个 token 向量的第
段。
第三步:算注意力分数。 对每个 head 独立算 ,得到
的矩阵。
第四步:因果掩码。 这是 "Masked" 的关键——位置 只能看到
的位置。把
的分数设成一个很大的负数(比如
),这样 softmax 之后这些位置的权重基本是 0。
第五步:数值稳定的 Softmax。 对每一行,先减去行最大值再取指数,避免溢出:
$$
第六步:加权求和。 用 softmax 权重对 做加权求和。
第七步:拼回去,最终投影。 把所有 head 的输出沿最后一维拼接回 ,再乘
得到最终输出。
复杂度
设 为 batch size,
为序列长度,
为
。
- 线性投影:
- 注意力分数:
- 总体
代码
import math
def solve():
line = input().strip()
parts = line.split(';')
num_heads = int(parts[0].strip())
X = eval(parts[1].strip())
W_Q = eval(parts[2].strip())
W_K = eval(parts[3].strip())
W_V = eval(parts[4].strip())
W_O = eval(parts[5].strip())
batch = len(X)
seq = len(X[0])
d_model = len(X[0][0])
d_k = d_model // num_heads
def matmul_3d_2d(A, B):
res = []
for b in range(len(A)):
br = []
for s in range(len(A[b])):
row = []
for j in range(len(B[0])):
val = sum(A[b][s][k] * B[k][j] for k in range(len(B)))
row.append(val)
br.append(row)
res.append(br)
return res
Q = matmul_3d_2d(X, W_Q)
K = matmul_3d_2d(X, W_K)
V = matmul_3d_2d(X, W_V)
def reshape_to_heads(M):
res = []
for b in range(batch):
heads = []
for h in range(num_heads):
head = []
for s in range(seq):
head.append(M[b][s][h * d_k:(h + 1) * d_k])
heads.append(head)
res.append(heads)
return res
Q_h, K_h, V_h = reshape_to_heads(Q), reshape_to_heads(K), reshape_to_heads(V)
scale = math.sqrt(d_k)
NEG_INF = -1e9
# 注意力分数 + 掩码 + softmax + 加权求和
attn_out = []
for b in range(batch):
b_out = []
for h in range(num_heads):
# scores: [seq, seq]
scores = []
for i in range(seq):
row = []
for j in range(seq):
val = sum(Q_h[b][h][i][k] * K_h[b][h][j][k] for k in range(d_k))
row.append(val / scale if j <= i else NEG_INF)
row_max = max(row)
exps = [math.exp(x - row_max) for x in row]
s = sum(exps)
weights = [e / s for e in exps]
# 直接算这一行的输出
out_row = []
for k in range(d_k):
out_row.append(sum(weights[j] * V_h[b][h][j][k] for j in range(seq)))
scores.append(out_row)
b_out.append(scores)
attn_out.append(b_out)
# 拼接 + W_O
concat = []
for b in range(batch):
bc = []
for s in range(seq):
row = []
for h in range(num_heads):
row.extend(attn_out[b][h][s])
bc.append(row)
concat.append(bc)
output = matmul_3d_2d(concat, W_O)
def fmt(lst):
if isinstance(lst[0], list):
return '[' + ', '.join(fmt(x) for x in lst) + ']'
return '[' + ', '.join(f"{v:.2f}" for v in lst) + ']'
for b in range(batch):
for s in range(seq):
for d in range(d_model):
output[b][s][d] = round(output[b][s][d], 2)
print(fmt(output))
solve()

京公网安备 11010502036488号