题目链接
实现Masked Multi-Head Self-Attention
题目描述
给定一个批次的序列输入 (形状为
[batch, seq, d_model]
)以及四个权重矩阵 (均为
d_model × d_model
),你需要手动实现一个带因果掩码(Causal Mask)的多头自注意力机制。
核心步骤:
- 线性投射: 计算
,
,
。
- 多头拆分: 将
的最后一维
d_model
拆分为num_heads
个头,每个头的维度为d_k = d_model / num_heads
。形状从[batch, seq, d_model]
变为[batch, num_heads, seq, d_k]
。 - 计算注意力分数: 计算
scores = (Q @ K^T) / sqrt(d_k)
。其中@
代表在最后两维进行矩阵乘法。 - 因果掩码: 对
scores
应用一个下三角掩码,将上三角部分(不含对角线)的值设置为一个极小的负数(如-1e9
),以防止模型看到未来的信息。 - Softmax: 对加了掩码的
scores
在最后一个维度上进行数值稳定的 Softmax 操作,得到注意力权重。 - 加权求和: 将 Softmax 得到的权重与
相乘
attention = weights @ V
。 - 多头合并与输出投射: 将多头的结果合并,形状从
[batch, num_heads, seq, d_k]
恢复为[batch, seq, d_model]
,然后乘以输出权重矩阵。
输出要求:结果保留两位小数,并以 Python 列表格式输出。
解题思路
本题是一次对 Transformer 模型核心组件的底层模拟实现。解题的关键在于将复杂的张量运算分解为基础的循环和数学运算,并严格遵循题目给出的每一步。由于输入格式为 Python 风格的字符串,Python 的 eval
函数在解析上具有天然优势,而 C++ 和 Java 则需要手动编写一个解析器。
通用步骤:
-
输入解析:
- Python: 使用
str.split(';')
分割参数,再用eval()
将矩阵字符串直接转换为嵌套列表。 - C++/Java: 需要手动编写一个解析器。一个可行的策略是,对于矩阵字符串,通过追踪方括号
[]
的嵌套深度来找到顶层逗号,
,从而分割出子矩阵/向量,再进行递归解析。
- Python: 使用
-
辅助函数:
- 需要实现一个
matmul
函数,用于计算两个二维矩阵(List<List>
或vector<vector>
)的乘法。 - 所有计算过程应使用浮点数(
double
)以保证精度。
- 需要实现一个
-
模拟注意力计算:
- Step 1 (线性投射): 遍历
batch
维度,对每个X[b]
和权重矩阵做matmul
。 - Step 2 (多头拆分): 实现一个
reshape
函数。通过循环将[batch, seq, d_model]
形状的张量转换为[batch, num_heads, seq, d_k]
。 - Step 3 (计算分数): 遍历
batch
和num_heads
维度。对每个头的矩阵和
矩阵的转置做
matmul
,然后将结果除以。
- Step 4 (掩码): 对上一步得到的分数矩阵,用一个嵌套循环遍历其最后两维,将上三角部分的值设为
-1e9
。 - Step 5 (Softmax): 实现一个数值稳定的
softmax
函数。对于分数矩阵的每一行,先减去该行的最大值,再计算exp
,最后归一化。 - Step 6 (加权求和): 再次调用
matmul
,将 Softmax 权重矩阵与多头拆分后的相乘。
- Step 7 (合并与输出): 实现一个“逆向”的
reshape
操作,将[batch, num_heads, seq, d_k]
合并回[batch, seq, d_model]
。最后,再次调用matmul
将结果与相乘。
- Step 1 (线性投射): 遍历
-
格式化输出:
- 遍历最终结果张量,将每个浮点数四舍五入到两位小数,并按题目要求的 Python 列表格式打印。
代码
import math
import sys
# 为了简洁,直接在主函数内实现所有逻辑
def main():
# 1. 解析输入
line = input()
parts = line.split(';')
num_heads = int(parts[0])
X = eval(parts[1])
W_Q = eval(parts[2])
W_K = eval(parts[3])
W_V = eval(parts[4])
W_O = eval(parts[5])
batch, seq, d_model = len(X), len(X[0]), len(X[0][0])
d_k = d_model // num_heads
# 辅助函数:矩阵乘法
def matmul(A, B):
rows_A, cols_A = len(A), len(A[0])
rows_B, cols_B = len(B), len(B[0])
C = [[0.0 for _ in range(cols_B)] for _ in range(rows_A)]
for i in range(rows_A):
for j in range(cols_B):
for k in range(cols_A):
C[i][j] += A[i][k] * B[k][j]
return C
# Step 1: 线性投射
Q = [matmul(X[b], W_Q) for b in range(batch)]
K = [matmul(X[b], W_K) for b in range(batch)]
V = [matmul(X[b], W_V) for b in range(batch)]
# Step 2: 多头拆分 [b, s, d_m] -> [b, h, s, d_k]
Q_heads = [[[[0.0 for _ in range(d_k)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
K_heads = [[[[0.0 for _ in range(d_k)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
V_heads = [[[[0.0 for _ in range(d_k)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
for b in range(batch):
for s in range(seq):
for h in range(num_heads):
for i in range(d_k):
idx = h * d_k + i
Q_heads[b][h][s][i] = Q[b][s][idx]
K_heads[b][h][s][i] = K[b][s][idx]
V_heads[b][h][s][i] = V[b][s][idx]
# Step 3: 计算注意力分数
scores = [[[[0.0 for _ in range(seq)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
scale = math.sqrt(d_k)
for b in range(batch):
for h in range(num_heads):
# Q[b][h] @ K[b][h].T
q_slice = Q_heads[b][h]
k_slice = K_heads[b][h]
for i in range(seq):
for j in range(seq):
dot_product = sum(q_slice[i][k] * k_slice[j][k] for k in range(d_k))
scores[b][h][i][j] = dot_product / scale
# Step 4: 因果掩码
for b in range(batch):
for h in range(num_heads):
for i in range(seq):
for j in range(i + 1, seq):
scores[b][h][i][j] = -1e9
# Step 5: Softmax
weights = [[[[0.0 for _ in range(seq)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
for b in range(batch):
for h in range(num_heads):
for i in range(seq):
row = scores[b][h][i]
max_val = max(row)
exps = [math.exp(v - max_val) for v in row]
sum_exps = sum(exps)
weights[b][h][i] = [e / sum_exps for e in exps]
# Step 6: 加权求和
attention = [[[[0.0 for _ in range(d_k)] for _ in range(seq)] for _ in range(num_heads)] for _ in range(batch)]
for b in range(batch):
for h in range(num_heads):
attention[b][h] = matmul(weights[b][h], V_heads[b][h])
# Step 7: 多头合并与输出投射
# a. 合并 [b, h, s, d_k] -> [b, s, d_m]
concat_attention = [[[0.0 for _ in range(d_model)] for _ in range(seq)] for _ in range(batch)]
for b in range(batch):
for s in range(seq):
for h in range(num_heads):
for i in range(d_k):
concat_attention[b][s][h * d_k + i] = attention[b][h][s][i]
# b. 输出投射
final_output = [matmul(concat_attention[b], W_O) for b in range(batch)]
# 格式化输出
rounded_output = [[[round(val, 2) for val in vec] for vec in mat] for mat in final_output]
# Python list to string
def to_str(l):
if isinstance(l, list):
return "[" + ", ".join(to_str(i) for i in l) + "]"
return f"{l:.2f}"
print(to_str(rounded_output))
if __name__ == "__main__":
main()
算法及复杂度
- 算法:Transformer 自注意力机制模拟
- 时间复杂度:
。
- 线性投射为
。
- 注意力分数计算为
。
- 后续步骤复杂度较低。瓶颈通常在注意力分数的计算上。
- 线性投射为
- 空间复杂度:
,主要由存储注意力分数矩阵
scores
决定。