医疗诊断模型的训练
题意
给定一个简单的"线性映射 + 线性分类"模型,需要完成三步计算:
- 前向传播:输入矩阵
(
)经过映射矩阵
(
)得到隐层
,再经过分类矩阵
(
)得到打分
,对
条记录取平均得到预测向量
(
维)。
- 计算 MSE 损失:
。
- SGD 更新:用学习率
对
和
各做一步梯度下降。
输出预测值、损失值、更新后的两个权重矩阵,保留 2 位小数。
思路
前向传播就是简单的矩阵乘法加平均,关键在于反向传播的梯度推导。
设 ,其中
,
。
MSE 对 的梯度:
$$
由于 是对
按行取平均,梯度均匀分配到每一行:
$$
然后按矩阵乘法的链式法则:
$$
$$
$$
最后用 SGD 更新:。
整个过程就是手写一遍神经网络的前向和反向传播,没有激活函数,纯线性运算。
时间复杂度 ,空间复杂度
。
代码
import sys
def solve():
data = sys.stdin.read().split('\n')
line0 = data[0].split(',')
L, D, K = int(line0[0]), int(line0[1]), int(line0[2])
eta = float(line0[3])
y = list(map(float, data[1].split(',')))
X_flat = list(map(float, data[2].split(',')))
X = [X_flat[i * D:(i + 1) * D] for i in range(L)]
W_mlp_flat = list(map(float, data[3].split(',')))
W_mlp = [W_mlp_flat[i * D:(i + 1) * D] for i in range(D)]
W_cls_flat = list(map(float, data[4].split(',')))
W_cls = [W_cls_flat[i * K:(i + 1) * K] for i in range(D)]
# 前向传播:H = X @ W_mlp (L x D)
H = [[sum(X[i][k] * W_mlp[k][j] for k in range(D)) for j in range(D)] for i in range(L)]
# S = H @ W_cls (L x K)
S = [[sum(H[i][k] * W_cls[k][j] for k in range(D)) for j in range(K)] for i in range(L)]
# y_pred = mean(S, axis=0)
y_pred = [sum(S[i][j] for i in range(L)) / L for j in range(K)]
# MSE
mse = sum((y_pred[j] - y[j]) ** 2 for j in range(K)) / K
# 反向传播
d_ypred = [2.0 * (y_pred[j] - y[j]) / K for j in range(K)]
d_S = [[d_ypred[j] / L for j in range(K)] for _ in range(L)]
d_W_cls = [[sum(H[l][i] * d_S[l][j] for l in range(L)) for j in range(K)] for i in range(D)]
d_H = [[sum(d_S[i][k] * W_cls[j][k] for k in range(K)) for j in range(D)] for i in range(L)]
d_W_mlp = [[sum(X[l][i] * d_H[l][j] for l in range(L)) for j in range(D)] for i in range(D)]
# SGD 更新
for i in range(D):
for j in range(D):
W_mlp[i][j] -= eta * d_W_mlp[i][j]
for i in range(D):
for j in range(K):
W_cls[i][j] -= eta * d_W_cls[i][j]
# 输出
print(','.join(f'{v:.2f}' for v in y_pred))
print(f'{mse:.2f}')
print(','.join(f'{W_mlp[i][j]:.2f}' for i in range(D) for j in range(D)))
print(','.join(f'{W_cls[i][j]:.2f}' for i in range(D) for j in range(K)))
solve()
复杂度分析
- 时间复杂度:
,主要是矩阵乘法的开销。
- 空间复杂度:
,存储中间矩阵
、
及梯度。

京公网安备 11010502036488号