REALHW94 医疗诊断模型的训练

题目链接

医疗诊断模型的训练

题目描述

某医疗系统要用一次“线性映射 + 线性分类”结构对问卷症状序列做三步计算:前向预测、MSE 损失、一次 SGD 权重更新。

设一条问卷包含 条症状记录,每条症状是 维向量。先用一个 的权重矩阵把每条症状做线性变换,再用一个 的权重矩阵得到 维分类打分。把所有记录的打分在“症状条目维度”求平均,得到最终的 维预测向量(不做 softmax 归一化)。

随后与给定的 维真实向量做 MSE 损失,并用学习率 进行一次 SGD 更新这两个权重矩阵(均无偏置)。

输入描述:

  • 输入第 1 行:
  • 第 2 行:真实向量 个数)
  • 第 3 行:序列矩阵 (按行展平,共 个数)
  • 第 4 行:映射矩阵 (按行展平,共 个数)
  • 第 5 行:分类矩阵 (按行展平,共 个数)

输出描述:

  • 输出共 4 行,均为行优先展平与输出,四舍入保留 2 位小数:
    1. 个数)
    2. (1 个数)
    3. 更新后的 个数)
    4. 更新后的 个数)

计算规则:

  • (逐行相乘)
  • 参数更新:

解题思路

本题要求我们严格按照给定的计算规则,模拟一次神经网络的训练步骤,包括前向传播、损失计算和反向传播(梯度下降更新权重)。这是一个纯粹的数值模拟问题,核心在于正确实现矩阵和向量的各种运算。

整体流程可以分解为以下几个主要步骤:

  1. 数据读取与初始化

    • 读取 四个标量参数。
    • 读取真实向量
    • 读取并重构输入矩阵 (从一维数组到 的二维矩阵)。
    • 读取并重构权重矩阵 )和 )。
  2. 前向传播 (Forward Pass)

    • 计算 。这是一个矩阵乘法。对于 的每一行 (一个 维向量),计算 。结果 是一个 的矩阵。
    • 计算 :对矩阵 的所有行向量求按位的平均值,得到一个 维的向量
    • 计算 。这是一个向量-矩阵乘法,结果 是一个 维的预测向量。
  3. 损失计算 (Loss Calculation)

    • 计算 :根据均方误差公式 计算损失值。
  4. 反向传播与权重更新 (Backward Pass & Weight Update)

    • 计算梯度 ,这是损失函数对预测向量 的梯度。
    • 计算 的外积。结果是一个 的矩阵,其中
    • 计算 :为了将梯度反向传播到 ,首先计算 。这是一个向量-矩阵乘法,其中 的转置矩阵。
    • 计算 :与计算 类似,对输入矩阵 的所有行向量求按位的平均值,得到 维向量
    • 计算 的外积,结果是一个 的矩阵。
    • 更新权重:根据梯度下降法更新两个权重矩阵:
  5. 格式化输出

    • 将计算得到的 以及更新后的 按要求展平成一维,并四舍五入保留两位小数进行输出。

整个过程不涉及复杂的算法,但需要细致地处理数据结构和数值计算,特别是矩阵乘法、转置和外积等操作。

代码实现

#include <iostream>
#include <vector>
#include <iomanip>
#include <cmath>
#include <string>
#include <sstream>
#include <algorithm>

using namespace std;

// 用于打印向量(展平矩阵)
void print_vector(const vector<double>& vec) {
    for (size_t i = 0; i < vec.size(); ++i) {
        cout << (i > 0 ? "," : "") << vec[i];
    }
    cout << endl;
}

// 用于打印矩阵
void print_matrix(const vector<vector<double>>& matrix) {
    bool first = true;
    for (const auto& row : matrix) {
        for (double val : row) {
            if (!first) {
                cout << ",";
            }
            cout << val;
            first = false;
        }
    }
    cout << endl;
}

int main() {
    string line;
    stringstream ss;

    // 读取 L, D, K, eta
    getline(cin, line);
    replace(line.begin(), line.end(), ',', ' ');
    ss.str(line);
    int L, D, K;
    double eta;
    ss >> L >> D >> K >> eta;
    ss.clear();

    // 读取真实向量 y
    getline(cin, line);
    replace(line.begin(), line.end(), ',', ' ');
    ss.str(line);
    vector<double> y(K);
    for (int i = 0; i < K; ++i) ss >> y[i];
    ss.clear();

    // 读取序列矩阵 X (已展平)
    getline(cin, line);
    replace(line.begin(), line.end(), ',', ' ');
    ss.str(line);
    vector<vector<double>> X(L, vector<double>(D));
    for (int i = 0; i < L; ++i) {
        for (int j = 0; j < D; ++j) {
            ss >> X[i][j];
        }
    }
    ss.clear();

    // 读取映射矩阵 W_mlp (已展平)
    getline(cin, line);
    replace(line.begin(), line.end(), ',', ' ');
    ss.str(line);
    vector<vector<double>> W_mlp(D, vector<double>(D));
    for (int i = 0; i < D; ++i) {
        for (int j = 0; j < D; ++j) {
            ss >> W_mlp[i][j];
        }
    }
    ss.clear();

    // 读取分类矩阵 W_cls (已展平)
    getline(cin, line);
    replace(line.begin(), line.end(), ',', ' ');
    ss.str(line);
    vector<vector<double>> W_cls(D, vector<double>(K));
    for (int i = 0; i < D; ++i) {
        for (int j = 0; j < K; ++j) {
            ss >> W_cls[i][j];
        }
    }
    ss.clear();

    // 步骤1:前向传播
    // 计算 H = X @ W_mlp
    vector<vector<double>> H(L, vector<double>(D, 0.0));
    for (int i = 0; i < L; ++i) {
        for (int j = 0; j < D; ++j) {
            for (int k = 0; k < D; ++k) {
                H[i][j] += X[i][k] * W_mlp[k][j];
            }
        }
    }

    // 计算 h_mean
    vector<double> h_mean(D, 0.0);
    for (int i = 0; i < L; ++i) {
        for (int j = 0; j < D; ++j) {
            h_mean[j] += H[i][j];
        }
    }
    for (int j = 0; j < D; ++j) {
        h_mean[j] /= L;
    }

    // 计算 y_pred = h_mean @ W_cls
    vector<double> y_pred(K, 0.0);
    for (int j = 0; j < K; ++j) {
        for (int i = 0; i < D; ++i) {
            y_pred[j] += h_mean[i] * W_cls[i][j];
        }
    }

    // 步骤2:计算 MSE 损失
    double mse = 0.0;
    for (int i = 0; i < K; ++i) {
        mse += pow(y_pred[i] - y[i], 2);
    }
    mse /= K;

    // 步骤3:反向传播与权重更新
    // 计算 g
    vector<double> g(K);
    for (int i = 0; i < K; ++i) {
        g[i] = (2.0 / K) * (y_pred[i] - y[i]);
    }

    // 计算 grad_W_cls = 外积(h_mean, g)
    vector<vector<double>> grad_W_cls(D, vector<double>(K));
    for (int i = 0; i < D; ++i) {
        for (int j = 0; j < K; ++j) {
            grad_W_cls[i][j] = h_mean[i] * g[j];
        }
    }
    
    // 计算 u = g @ W_cls^T
    vector<double> u(D, 0.0);
    for (int i = 0; i < D; ++i) {
        for (int j = 0; j < K; ++j) {
            u[i] += g[j] * W_cls[i][j];
        }
    }

    // 计算 x_mean
    vector<double> x_mean(D, 0.0);
    for (int i = 0; i < L; ++i) {
        for (int j = 0; j < D; ++j) {
            x_mean[j] += X[i][j];
        }
    }
    for (int j = 0; j < D; ++j) {
        x_mean[j] /= L;
    }

    // 计算 grad_W_mlp = 外积(x_mean, u)
    vector<vector<double>> grad_W_mlp(D, vector<double>(D));
    for (int i = 0; i < D; ++i) {
        for (int j = 0; j < D; ++j) {
            grad_W_mlp[i][j] = x_mean[i] * u[j];
        }
    }

    // 更新权重
    for (int i = 0; i < D; ++i) {
        for (int j = 0; j < D; ++j) {
            W_mlp[i][j] -= eta * grad_W_mlp[i][j];
        }
    }
    for (int i = 0; i < D; ++i) {
        for (int j = 0; j < K; ++j) {
            W_cls[i][j] -= eta * grad_W_cls[i][j];
        }
    }

    // 步骤4:输出结果
    cout << fixed << setprecision(2);
    print_vector(y_pred);
    cout << mse << endl;
    print_matrix(W_mlp);
    print_matrix(W_cls);

    return 0;
}
import java.util.Scanner;
import java.util.Locale;

public class Main {
    public static void main(String[] args) {
        // 使用 Scanner sc 进行输入,并设置分隔符为逗号或空白
        Scanner sc = new Scanner(System.in).useDelimiter("[\\s,]+");

        // 读取 L, D, K, eta
        int L = sc.nextInt();
        int D = sc.nextInt();
        int K = sc.nextInt();
        double eta = sc.nextDouble();

        // 读取真实向量 y
        double[] y = new double[K];
        for (int i = 0; i < K; i++) {
            y[i] = sc.nextDouble();
        }

        // 读取序列矩阵 X
        double[][] X = new double[L][D];
        for (int i = 0; i < L; i++) {
            for (int j = 0; j < D; j++) {
                X[i][j] = sc.nextDouble();
            }
        }

        // 读取映射矩阵 W_mlp
        double[][] W_mlp = new double[D][D];
        for (int i = 0; i < D; i++) {
            for (int j = 0; j < D; j++) {
                W_mlp[i][j] = sc.nextDouble();
            }
        }

        // 读取分类矩阵 W_cls
        double[][] W_cls = new double[D][K];
        for (int i = 0; i < D; i++) {
            for (int j = 0; j < K; j++) {
                W_cls[i][j] = sc.nextDouble();
            }
        }

        // 步骤1:前向传播
        // 计算 H = X @ W_mlp
        double[][] H = new double[L][D];
        for (int i = 0; i < L; i++) {
            for (int j = 0; j < D; j++) {
                for (int k = 0; k < D; k++) {
                    H[i][j] += X[i][k] * W_mlp[k][j];
                }
            }
        }

        // 计算 h_mean
        double[] h_mean = new double[D];
        for (int i = 0; i < L; i++) {
            for (int j = 0; j < D; j++) {
                h_mean[j] += H[i][j];
            }
        }
        for (int j = 0; j < D; j++) {
            h_mean[j] /= L;
        }

        // 计算 y_pred = h_mean @ W_cls
        double[] y_pred = new double[K];
        for (int j = 0; j < K; j++) {
            for (int i = 0; i < D; i++) {
                y_pred[j] += h_mean[i] * W_cls[i][j];
            }
        }

        // 步骤2:计算 MSE 损失
        double mse = 0.0;
        for (int i = 0; i < K; i++) {
            mse += Math.pow(y_pred[i] - y[i], 2);
        }
        mse /= K;

        // 步骤3:反向传播与权重更新
        // 计算 g
        double[] g = new double[K];
        for (int i = 0; i < K; i++) {
            g[i] = (2.0 / K) * (y_pred[i] - y[i]);
        }

        // 计算 grad_W_cls = 外积(h_mean, g)
        double[][] grad_W_cls = new double[D][K];
        for (int i = 0; i < D; i++) {
            for (int j = 0; j < K; j++) {
                grad_W_cls[i][j] = h_mean[i] * g[j];
            }
        }

        // 计算 u = g @ W_cls^T
        double[] u = new double[D];
        for (int i = 0; i < D; i++) {
            for (int j = 0; j < K; j++) {
                u[i] += g[j] * W_cls[i][j];
            }
        }

        // 计算 x_mean
        double[] x_mean = new double[D];
        for (int i = 0; i < L; i++) {
            for (int j = 0; j < D; j++) {
                x_mean[j] += X[i][j];
            }
        }
        for (int j = 0; j < D; j++) {
            x_mean[j] /= L;
        }

        // 计算 grad_W_mlp = 外积(x_mean, u)
        double[][] grad_W_mlp = new double[D][D];
        for (int i = 0; i < D; i++) {
            for (int j = 0; j < D; j++) {
                grad_W_mlp[i][j] = x_mean[i] * u[j];
            }
        }

        // 更新权重
        for (int i = 0; i < D; i++) {
            for (int j = 0; j < D; j++) {
                W_mlp[i][j] -= eta * grad_W_mlp[i][j];
            }
        }
        for (int i = 0; i < D; i++) {
            for (int j = 0; j < K; j++) {
                W_cls[i][j] -= eta * grad_W_cls[i][j];
            }
        }

        // 步骤4:输出结果
        for (int i = 0; i < K; i++) {
            System.out.printf(Locale.US, "%.2f%s", y_pred[i], i == K - 1 ? "" : ",");
        }
        System.out.println();
        System.out.printf(Locale.US, "%.2f\n", mse);
        for (int i = 0; i < D; i++) {
            for (int j = 0; j < D; j++) {
                 System.out.printf(Locale.US, "%.2f%s", W_mlp[i][j], (i == D - 1 && j == D - 1) ? "" : ",");
            }
        }
        System.out.println();
        for (int i = 0; i < D; i++) {
            for (int j = 0; j < K; j++) {
                 System.out.printf(Locale.US, "%.2f%s", W_cls[i][j], (i == D - 1 && j == K - 1) ? "" : ",");
            }
        }
        System.out.println();
    }
}
def main():
    # 读取 L, D, K, eta
    line1 = input().split(',')
    L, D, K = int(line1[0]), int(line1[1]), int(line1[2])
    eta = float(line1[3])

    # 读取真实向量 y
    y = list(map(float, input().split(',')))

    # 读取序列矩阵 X
    x_flat = list(map(float, input().split(',')))
    X = [x_flat[i * D:(i + 1) * D] for i in range(L)]

    # 读取映射矩阵 W_mlp
    w_mlp_flat = list(map(float, input().split(',')))
    W_mlp = [w_mlp_flat[i * D:(i + 1) * D] for i in range(D)]

    # 读取分类矩阵 W_cls
    w_cls_flat = list(map(float, input().split(',')))
    W_cls = [w_cls_flat[i * K:(i + 1) * K] for i in range(D)]

    # 步骤1:前向传播
    # 计算 H = X @ W_mlp
    H = [[0.0] * D for _ in range(L)]
    for i in range(L):
        for j in range(D):
            for k in range(D):
                H[i][j] += X[i][k] * W_mlp[k][j]

    # 计算 h_mean
    h_mean = [0.0] * D
    for i in range(L):
        for j in range(D):
            h_mean[j] += H[i][j]
    for j in range(D):
        h_mean[j] /= L

    # 计算 y_pred = h_mean @ W_cls
    y_pred = [0.0] * K
    for j in range(K):
        for i in range(D):
            y_pred[j] += h_mean[i] * W_cls[i][j]

    # 步骤2:计算 MSE 损失
    mse = sum([(y_pred[i] - y[i]) ** 2 for i in range(K)]) / K

    # 步骤3:反向传播与权重更新
    # 计算 g
    g = [(2.0 / K) * (y_pred[i] - y[i]) for i in range(K)]

    # 计算 grad_W_cls = 外积(h_mean, g)
    grad_W_cls = [[h_mean[i] * g[j] for j in range(K)] for i in range(D)]

    # 计算 u = g @ W_cls^T
    u = [0.0] * D
    for i in range(D):
        for j in range(K):
            u[i] += g[j] * W_cls[i][j]
    
    # 计算 x_mean
    x_mean = [0.0] * D
    for i in range(L):
        for j in range(D):
            x_mean[j] += X[i][j]
    for j in range(D):
        x_mean[j] /= L

    # 计算 grad_W_mlp = 外积(x_mean, u)
    grad_W_mlp = [[x_mean[i] * u[j] for j in range(D)] for i in range(D)]

    # 更新权重
    for i in range(D):
        for j in range(D):
            W_mlp[i][j] -= eta * grad_W_mlp[i][j]
    
    for i in range(D):
        for j in range(K):
            W_cls[i][j] -= eta * grad_W_cls[i][j]

    # 步骤4:输出结果
    print(",".join([f"{v:.2f}" for v in y_pred]))
    print(f"{mse:.2f}")
    
    w_mlp_flat_new = [item for sublist in W_mlp for item in sublist]
    print(",".join([f"{v:.2f}" for v in w_mlp_flat_new]))

    w_cls_flat_new = [item for sublist in W_cls for item in sublist]
    print(",".join([f"{v:.2f}" for v in w_cls_flat_new]))

if __name__ == "__main__":
    main()

算法及复杂度

  • 时间复杂度:算法的主要计算开销在于矩阵乘法。

    • 计算 需要
    • 计算 需要
    • 计算 需要
    • 计算 需要
    • 其他操作如梯度计算(外积)和权重更新的复杂度都低于矩阵乘法。
    • 因此,总的时间复杂度由最高阶项决定,为
  • 空间复杂度:算法需要存储所有输入的矩阵和向量,以及计算过程中产生的中间矩阵和向量。

    • 输入矩阵
    • 权重矩阵
    • 中间矩阵
    • 梯度矩阵
    • 因此,总的空间复杂度为