医疗诊断模型的训练

题意

给定一个简单的"线性映射 + 线性分类"模型,需要完成三步计算:

  1. 前向传播:输入矩阵 )经过映射矩阵 )得到隐层 ,再经过分类矩阵 )得到打分 ,对 条记录取平均得到预测向量 维)。
  2. 计算 MSE 损失
  3. SGD 更新:用学习率 各做一步梯度下降。

输出预测值、损失值、更新后的两个权重矩阵,保留 2 位小数。

思路

前向传播就是简单的矩阵乘法加平均,关键在于反向传播的梯度推导。

,其中

MSE 对 的梯度:

$$

由于 是对 按行取平均,梯度均匀分配到每一行:

$$

然后按矩阵乘法的链式法则:

$$

$$

$$

最后用 SGD 更新:

整个过程就是手写一遍神经网络的前向和反向传播,没有激活函数,纯线性运算。

时间复杂度 ,空间复杂度

代码

import sys

def solve():
    data = sys.stdin.read().split('\n')
    line0 = data[0].split(',')
    L, D, K = int(line0[0]), int(line0[1]), int(line0[2])
    eta = float(line0[3])

    y = list(map(float, data[1].split(',')))

    X_flat = list(map(float, data[2].split(',')))
    X = [X_flat[i * D:(i + 1) * D] for i in range(L)]

    W_mlp_flat = list(map(float, data[3].split(',')))
    W_mlp = [W_mlp_flat[i * D:(i + 1) * D] for i in range(D)]

    W_cls_flat = list(map(float, data[4].split(',')))
    W_cls = [W_cls_flat[i * K:(i + 1) * K] for i in range(D)]

    # 前向传播:H = X @ W_mlp (L x D)
    H = [[sum(X[i][k] * W_mlp[k][j] for k in range(D)) for j in range(D)] for i in range(L)]

    # S = H @ W_cls (L x K)
    S = [[sum(H[i][k] * W_cls[k][j] for k in range(D)) for j in range(K)] for i in range(L)]

    # y_pred = mean(S, axis=0)
    y_pred = [sum(S[i][j] for i in range(L)) / L for j in range(K)]

    # MSE
    mse = sum((y_pred[j] - y[j]) ** 2 for j in range(K)) / K

    # 反向传播
    d_ypred = [2.0 * (y_pred[j] - y[j]) / K for j in range(K)]
    d_S = [[d_ypred[j] / L for j in range(K)] for _ in range(L)]

    d_W_cls = [[sum(H[l][i] * d_S[l][j] for l in range(L)) for j in range(K)] for i in range(D)]
    d_H = [[sum(d_S[i][k] * W_cls[j][k] for k in range(K)) for j in range(D)] for i in range(L)]
    d_W_mlp = [[sum(X[l][i] * d_H[l][j] for l in range(L)) for j in range(D)] for i in range(D)]

    # SGD 更新
    for i in range(D):
        for j in range(D):
            W_mlp[i][j] -= eta * d_W_mlp[i][j]
    for i in range(D):
        for j in range(K):
            W_cls[i][j] -= eta * d_W_cls[i][j]

    # 输出
    print(','.join(f'{v:.2f}' for v in y_pred))
    print(f'{mse:.2f}')
    print(','.join(f'{W_mlp[i][j]:.2f}' for i in range(D) for j in range(D)))
    print(','.join(f'{W_cls[i][j]:.2f}' for i in range(D) for j in range(K)))

solve()

复杂度分析

  • 时间复杂度,主要是矩阵乘法的开销。
  • 空间复杂度,存储中间矩阵 及梯度。