REALHW94 医疗诊断模型的训练
题目链接
题目描述
某医疗系统要用一次“线性映射 + 线性分类”结构对问卷症状序列做三步计算:前向预测、MSE 损失、一次 SGD 权重更新。
设一条问卷包含 条症状记录,每条症状是
维向量。先用一个
的权重矩阵把每条症状做线性变换,再用一个
的权重矩阵得到
维分类打分。把所有记录的打分在“症状条目维度”求平均,得到最终的
维预测向量(不做 softmax 归一化)。
随后与给定的 维真实向量做 MSE 损失,并用学习率
进行一次 SGD 更新这两个权重矩阵(均无偏置)。
输入描述:
- 输入第 1 行:
- 第 2 行:真实向量
(
个数)
- 第 3 行:序列矩阵
(按行展平,共
个数)
- 第 4 行:映射矩阵
(按行展平,共
个数)
- 第 5 行:分类矩阵
(按行展平,共
个数)
输出描述:
- 输出共 4 行,均为行优先展平与输出,四舍入保留 2 位小数:
(
个数)
(1 个数)
- 更新后的
(
个数)
- 更新后的
(
个数)
计算规则:
(逐行相乘)
(
)
(
)
- 令
(
)
(
)
- 令
(
)
(
)
(
)
- 参数更新:
,
解题思路
本题要求我们严格按照给定的计算规则,模拟一次神经网络的训练步骤,包括前向传播、损失计算和反向传播(梯度下降更新权重)。这是一个纯粹的数值模拟问题,核心在于正确实现矩阵和向量的各种运算。
整体流程可以分解为以下几个主要步骤:
-
数据读取与初始化:
- 读取
四个标量参数。
- 读取真实向量
。
- 读取并重构输入矩阵
(从一维数组到
的二维矩阵)。
- 读取并重构权重矩阵
(
)和
(
)。
- 读取
-
前向传播 (Forward Pass):
- 计算
:
。这是一个矩阵乘法。对于
的每一行
(一个
维向量),计算
。结果
是一个
的矩阵。
- 计算
:对矩阵
的所有行向量求按位的平均值,得到一个
维的向量
。
- 计算
:
。这是一个向量-矩阵乘法,结果
是一个
维的预测向量。
- 计算
-
损失计算 (Loss Calculation):
- 计算
:根据均方误差公式
计算损失值。
- 计算
-
反向传播与权重更新 (Backward Pass & Weight Update):
- 计算梯度
:
,这是损失函数对预测向量
的梯度。
- 计算
:
是
和
的外积。结果是一个
的矩阵,其中
。
- 计算
:为了将梯度反向传播到
,首先计算
。这是一个向量-矩阵乘法,其中
是
的转置矩阵。
- 计算
:与计算
类似,对输入矩阵
的所有行向量求按位的平均值,得到
维向量
。
- 计算
:
是
和
的外积,结果是一个
的矩阵。
- 更新权重:根据梯度下降法更新两个权重矩阵:
- 计算梯度
-
格式化输出:
- 将计算得到的
、
以及更新后的
和
按要求展平成一维,并四舍五入保留两位小数进行输出。
- 将计算得到的
整个过程不涉及复杂的算法,但需要细致地处理数据结构和数值计算,特别是矩阵乘法、转置和外积等操作。
代码实现
#include <iostream>
#include <vector>
#include <iomanip>
#include <cmath>
#include <string>
#include <sstream>
#include <algorithm>
using namespace std;
// 用于打印向量(展平矩阵)
void print_vector(const vector<double>& vec) {
for (size_t i = 0; i < vec.size(); ++i) {
cout << (i > 0 ? "," : "") << vec[i];
}
cout << endl;
}
// 用于打印矩阵
void print_matrix(const vector<vector<double>>& matrix) {
bool first = true;
for (const auto& row : matrix) {
for (double val : row) {
if (!first) {
cout << ",";
}
cout << val;
first = false;
}
}
cout << endl;
}
int main() {
string line;
stringstream ss;
// 读取 L, D, K, eta
getline(cin, line);
replace(line.begin(), line.end(), ',', ' ');
ss.str(line);
int L, D, K;
double eta;
ss >> L >> D >> K >> eta;
ss.clear();
// 读取真实向量 y
getline(cin, line);
replace(line.begin(), line.end(), ',', ' ');
ss.str(line);
vector<double> y(K);
for (int i = 0; i < K; ++i) ss >> y[i];
ss.clear();
// 读取序列矩阵 X (已展平)
getline(cin, line);
replace(line.begin(), line.end(), ',', ' ');
ss.str(line);
vector<vector<double>> X(L, vector<double>(D));
for (int i = 0; i < L; ++i) {
for (int j = 0; j < D; ++j) {
ss >> X[i][j];
}
}
ss.clear();
// 读取映射矩阵 W_mlp (已展平)
getline(cin, line);
replace(line.begin(), line.end(), ',', ' ');
ss.str(line);
vector<vector<double>> W_mlp(D, vector<double>(D));
for (int i = 0; i < D; ++i) {
for (int j = 0; j < D; ++j) {
ss >> W_mlp[i][j];
}
}
ss.clear();
// 读取分类矩阵 W_cls (已展平)
getline(cin, line);
replace(line.begin(), line.end(), ',', ' ');
ss.str(line);
vector<vector<double>> W_cls(D, vector<double>(K));
for (int i = 0; i < D; ++i) {
for (int j = 0; j < K; ++j) {
ss >> W_cls[i][j];
}
}
ss.clear();
// 步骤1:前向传播
// 计算 H = X @ W_mlp
vector<vector<double>> H(L, vector<double>(D, 0.0));
for (int i = 0; i < L; ++i) {
for (int j = 0; j < D; ++j) {
for (int k = 0; k < D; ++k) {
H[i][j] += X[i][k] * W_mlp[k][j];
}
}
}
// 计算 h_mean
vector<double> h_mean(D, 0.0);
for (int i = 0; i < L; ++i) {
for (int j = 0; j < D; ++j) {
h_mean[j] += H[i][j];
}
}
for (int j = 0; j < D; ++j) {
h_mean[j] /= L;
}
// 计算 y_pred = h_mean @ W_cls
vector<double> y_pred(K, 0.0);
for (int j = 0; j < K; ++j) {
for (int i = 0; i < D; ++i) {
y_pred[j] += h_mean[i] * W_cls[i][j];
}
}
// 步骤2:计算 MSE 损失
double mse = 0.0;
for (int i = 0; i < K; ++i) {
mse += pow(y_pred[i] - y[i], 2);
}
mse /= K;
// 步骤3:反向传播与权重更新
// 计算 g
vector<double> g(K);
for (int i = 0; i < K; ++i) {
g[i] = (2.0 / K) * (y_pred[i] - y[i]);
}
// 计算 grad_W_cls = 外积(h_mean, g)
vector<vector<double>> grad_W_cls(D, vector<double>(K));
for (int i = 0; i < D; ++i) {
for (int j = 0; j < K; ++j) {
grad_W_cls[i][j] = h_mean[i] * g[j];
}
}
// 计算 u = g @ W_cls^T
vector<double> u(D, 0.0);
for (int i = 0; i < D; ++i) {
for (int j = 0; j < K; ++j) {
u[i] += g[j] * W_cls[i][j];
}
}
// 计算 x_mean
vector<double> x_mean(D, 0.0);
for (int i = 0; i < L; ++i) {
for (int j = 0; j < D; ++j) {
x_mean[j] += X[i][j];
}
}
for (int j = 0; j < D; ++j) {
x_mean[j] /= L;
}
// 计算 grad_W_mlp = 外积(x_mean, u)
vector<vector<double>> grad_W_mlp(D, vector<double>(D));
for (int i = 0; i < D; ++i) {
for (int j = 0; j < D; ++j) {
grad_W_mlp[i][j] = x_mean[i] * u[j];
}
}
// 更新权重
for (int i = 0; i < D; ++i) {
for (int j = 0; j < D; ++j) {
W_mlp[i][j] -= eta * grad_W_mlp[i][j];
}
}
for (int i = 0; i < D; ++i) {
for (int j = 0; j < K; ++j) {
W_cls[i][j] -= eta * grad_W_cls[i][j];
}
}
// 步骤4:输出结果
cout << fixed << setprecision(2);
print_vector(y_pred);
cout << mse << endl;
print_matrix(W_mlp);
print_matrix(W_cls);
return 0;
}
import java.util.Scanner;
import java.util.Locale;
public class Main {
public static void main(String[] args) {
// 使用 Scanner sc 进行输入,并设置分隔符为逗号或空白
Scanner sc = new Scanner(System.in).useDelimiter("[\\s,]+");
// 读取 L, D, K, eta
int L = sc.nextInt();
int D = sc.nextInt();
int K = sc.nextInt();
double eta = sc.nextDouble();
// 读取真实向量 y
double[] y = new double[K];
for (int i = 0; i < K; i++) {
y[i] = sc.nextDouble();
}
// 读取序列矩阵 X
double[][] X = new double[L][D];
for (int i = 0; i < L; i++) {
for (int j = 0; j < D; j++) {
X[i][j] = sc.nextDouble();
}
}
// 读取映射矩阵 W_mlp
double[][] W_mlp = new double[D][D];
for (int i = 0; i < D; i++) {
for (int j = 0; j < D; j++) {
W_mlp[i][j] = sc.nextDouble();
}
}
// 读取分类矩阵 W_cls
double[][] W_cls = new double[D][K];
for (int i = 0; i < D; i++) {
for (int j = 0; j < K; j++) {
W_cls[i][j] = sc.nextDouble();
}
}
// 步骤1:前向传播
// 计算 H = X @ W_mlp
double[][] H = new double[L][D];
for (int i = 0; i < L; i++) {
for (int j = 0; j < D; j++) {
for (int k = 0; k < D; k++) {
H[i][j] += X[i][k] * W_mlp[k][j];
}
}
}
// 计算 h_mean
double[] h_mean = new double[D];
for (int i = 0; i < L; i++) {
for (int j = 0; j < D; j++) {
h_mean[j] += H[i][j];
}
}
for (int j = 0; j < D; j++) {
h_mean[j] /= L;
}
// 计算 y_pred = h_mean @ W_cls
double[] y_pred = new double[K];
for (int j = 0; j < K; j++) {
for (int i = 0; i < D; i++) {
y_pred[j] += h_mean[i] * W_cls[i][j];
}
}
// 步骤2:计算 MSE 损失
double mse = 0.0;
for (int i = 0; i < K; i++) {
mse += Math.pow(y_pred[i] - y[i], 2);
}
mse /= K;
// 步骤3:反向传播与权重更新
// 计算 g
double[] g = new double[K];
for (int i = 0; i < K; i++) {
g[i] = (2.0 / K) * (y_pred[i] - y[i]);
}
// 计算 grad_W_cls = 外积(h_mean, g)
double[][] grad_W_cls = new double[D][K];
for (int i = 0; i < D; i++) {
for (int j = 0; j < K; j++) {
grad_W_cls[i][j] = h_mean[i] * g[j];
}
}
// 计算 u = g @ W_cls^T
double[] u = new double[D];
for (int i = 0; i < D; i++) {
for (int j = 0; j < K; j++) {
u[i] += g[j] * W_cls[i][j];
}
}
// 计算 x_mean
double[] x_mean = new double[D];
for (int i = 0; i < L; i++) {
for (int j = 0; j < D; j++) {
x_mean[j] += X[i][j];
}
}
for (int j = 0; j < D; j++) {
x_mean[j] /= L;
}
// 计算 grad_W_mlp = 外积(x_mean, u)
double[][] grad_W_mlp = new double[D][D];
for (int i = 0; i < D; i++) {
for (int j = 0; j < D; j++) {
grad_W_mlp[i][j] = x_mean[i] * u[j];
}
}
// 更新权重
for (int i = 0; i < D; i++) {
for (int j = 0; j < D; j++) {
W_mlp[i][j] -= eta * grad_W_mlp[i][j];
}
}
for (int i = 0; i < D; i++) {
for (int j = 0; j < K; j++) {
W_cls[i][j] -= eta * grad_W_cls[i][j];
}
}
// 步骤4:输出结果
for (int i = 0; i < K; i++) {
System.out.printf(Locale.US, "%.2f%s", y_pred[i], i == K - 1 ? "" : ",");
}
System.out.println();
System.out.printf(Locale.US, "%.2f\n", mse);
for (int i = 0; i < D; i++) {
for (int j = 0; j < D; j++) {
System.out.printf(Locale.US, "%.2f%s", W_mlp[i][j], (i == D - 1 && j == D - 1) ? "" : ",");
}
}
System.out.println();
for (int i = 0; i < D; i++) {
for (int j = 0; j < K; j++) {
System.out.printf(Locale.US, "%.2f%s", W_cls[i][j], (i == D - 1 && j == K - 1) ? "" : ",");
}
}
System.out.println();
}
}
def main():
# 读取 L, D, K, eta
line1 = input().split(',')
L, D, K = int(line1[0]), int(line1[1]), int(line1[2])
eta = float(line1[3])
# 读取真实向量 y
y = list(map(float, input().split(',')))
# 读取序列矩阵 X
x_flat = list(map(float, input().split(',')))
X = [x_flat[i * D:(i + 1) * D] for i in range(L)]
# 读取映射矩阵 W_mlp
w_mlp_flat = list(map(float, input().split(',')))
W_mlp = [w_mlp_flat[i * D:(i + 1) * D] for i in range(D)]
# 读取分类矩阵 W_cls
w_cls_flat = list(map(float, input().split(',')))
W_cls = [w_cls_flat[i * K:(i + 1) * K] for i in range(D)]
# 步骤1:前向传播
# 计算 H = X @ W_mlp
H = [[0.0] * D for _ in range(L)]
for i in range(L):
for j in range(D):
for k in range(D):
H[i][j] += X[i][k] * W_mlp[k][j]
# 计算 h_mean
h_mean = [0.0] * D
for i in range(L):
for j in range(D):
h_mean[j] += H[i][j]
for j in range(D):
h_mean[j] /= L
# 计算 y_pred = h_mean @ W_cls
y_pred = [0.0] * K
for j in range(K):
for i in range(D):
y_pred[j] += h_mean[i] * W_cls[i][j]
# 步骤2:计算 MSE 损失
mse = sum([(y_pred[i] - y[i]) ** 2 for i in range(K)]) / K
# 步骤3:反向传播与权重更新
# 计算 g
g = [(2.0 / K) * (y_pred[i] - y[i]) for i in range(K)]
# 计算 grad_W_cls = 外积(h_mean, g)
grad_W_cls = [[h_mean[i] * g[j] for j in range(K)] for i in range(D)]
# 计算 u = g @ W_cls^T
u = [0.0] * D
for i in range(D):
for j in range(K):
u[i] += g[j] * W_cls[i][j]
# 计算 x_mean
x_mean = [0.0] * D
for i in range(L):
for j in range(D):
x_mean[j] += X[i][j]
for j in range(D):
x_mean[j] /= L
# 计算 grad_W_mlp = 外积(x_mean, u)
grad_W_mlp = [[x_mean[i] * u[j] for j in range(D)] for i in range(D)]
# 更新权重
for i in range(D):
for j in range(D):
W_mlp[i][j] -= eta * grad_W_mlp[i][j]
for i in range(D):
for j in range(K):
W_cls[i][j] -= eta * grad_W_cls[i][j]
# 步骤4:输出结果
print(",".join([f"{v:.2f}" for v in y_pred]))
print(f"{mse:.2f}")
w_mlp_flat_new = [item for sublist in W_mlp for item in sublist]
print(",".join([f"{v:.2f}" for v in w_mlp_flat_new]))
w_cls_flat_new = [item for sublist in W_cls for item in sublist]
print(",".join([f"{v:.2f}" for v in w_cls_flat_new]))
if __name__ == "__main__":
main()
算法及复杂度
-
时间复杂度:算法的主要计算开销在于矩阵乘法。
- 计算
需要
。
- 计算
和
需要
。
- 计算
需要
。
- 计算
需要
。
- 其他操作如梯度计算(外积)和权重更新的复杂度都低于矩阵乘法。
- 因此,总的时间复杂度由最高阶项决定,为
。
- 计算
-
空间复杂度:算法需要存储所有输入的矩阵和向量,以及计算过程中产生的中间矩阵和向量。
- 输入矩阵
占
。
- 权重矩阵
占
,
占
。
- 中间矩阵
占
。
- 梯度矩阵
占
,
占
。
- 因此,总的空间复杂度为
。
- 输入矩阵

京公网安备 11010502036488号