REALHW88 Multi-Objective Recommendation Ranking Model Optimization
Problem Link
Problem Description
In a dual-objective recommendation ranking scenario, we must predict both the click-through rate (CTR) and the conversion rate (CVR). A single shared linear weight vector w extracts common features, while each task has its own bias, b_ctr and b_cvr.
Given a feature matrix X and a label matrix Y (each row of the form [ctr, cvr]), start from all-zero parameters and run batch gradient descent for N iterations with learning rate lr. After training, recompute the joint loss once using the final parameters:
- Prediction: y_hat_ctr = X·w + b_ctr, y_hat_cvr = X·w + b_cvr; MSE_ctr and MSE_cvr denote the mean squared errors of the two tasks.
- Joint loss: Loss = MSE_ctr + alpha × MSE_cvr
- Output: Loss × 10^10, rounded to an integer with Half Up rounding.
Solution Approach
This is a numerical problem that exactly simulates Batch Gradient Descent. We implement data parsing, model training, and the final loss computation strictly according to the formulas defined in the problem.
1. Data Parsing
First, write a helper function that parses an input string (e.g. "a,b;c,d;...") into a 2D matrix of floats: split the string into rows on ';', then split each row into values on ','.
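A minimal Python sketch of this helper (the full C++/Java/Python solutions below implement the same logic):

```python
def parse_matrix(s: str) -> list[list[float]]:
    # Rows are separated by ';', values within a row by ','.
    return [[float(v) for v in row.split(',')] for row in s.split(';')]

print(parse_matrix("1,2;3,4"))  # [[1.0, 2.0], [3.0, 4.0]]
```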
2. Parameter Initialization
- Let m be the number of samples (rows of X) and d the feature dimension (columns of X).
- Initialize the weight vector w as a zero vector of length d.
- Initialize the biases b_ctr and b_cvr to 0 (a short snippet follows this list).
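In code this amounts to the following; the value of X here is made up purely for illustration:

```python
X = [[1.0, 2.0], [3.0, 4.0]]  # hypothetical parsed feature matrix
m, d = len(X), len(X[0])      # sample count and feature dimension
w = [0.0] * d                 # shared weight vector, all zeros
b_ctr = b_cvr = 0.0           # task-specific biases
```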
3. Batch Gradient Descent
This is the core of the algorithm: we iterate N times, updating the parameters on every iteration.
3.1. Gradient Derivation
The joint loss function is:

Loss = (1/m) · Σ_{i=1..m} (y_hat_ctr_i − y_ctr_i)² + (alpha/m) · Σ_{i=1..m} (y_hat_cvr_i − y_cvr_i)²

We need the partial derivatives (gradients) of Loss with respect to w, b_ctr, and b_cvr (a numerical sanity check follows this list):
- Gradient w.r.t. b_ctr: ∂Loss/∂b_ctr = (2/m) · Σ_i (y_hat_ctr_i − y_ctr_i)
- Gradient w.r.t. b_cvr: ∂Loss/∂b_cvr = (2·alpha/m) · Σ_i (y_hat_cvr_i − y_cvr_i)
- Gradient w.r.t. w (a vector): ∂Loss/∂w = (2/m) · Σ_i [(y_hat_ctr_i − y_ctr_i) + alpha · (y_hat_cvr_i − y_cvr_i)] · x_i, where x_i is the i-th row of X.
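As a sanity check, the analytic gradients above can be compared with central finite differences; the tiny data below is made up purely for illustration:

```python
def joint_loss(X, Y, w, b_ctr, b_cvr, alpha):
    # Loss = MSE_ctr + alpha * MSE_cvr, exactly as defined in the problem.
    m = len(X)
    mse_ctr = mse_cvr = 0.0
    for i in range(m):
        dot = sum(X[i][j] * w[j] for j in range(len(w)))
        mse_ctr += (dot + b_ctr - Y[i][0]) ** 2
        mse_cvr += (dot + b_cvr - Y[i][1]) ** 2
    return mse_ctr / m + alpha * mse_cvr / m

X = [[1.0, 2.0], [3.0, 4.0]]
Y = [[1.0, 0.5], [0.0, 1.0]]
w, b_ctr, b_cvr, alpha, eps = [0.1, -0.2], 0.05, 0.02, 0.5, 1e-6

m = len(X)
# Analytic: dLoss/db_ctr = (2/m) * sum of CTR errors.
analytic = (2.0 / m) * sum(
    sum(X[i][j] * w[j] for j in range(len(w))) + b_ctr - Y[i][0]
    for i in range(m))
# Numerical: central difference in b_ctr.
numeric = (joint_loss(X, Y, w, b_ctr + eps, b_cvr, alpha)
           - joint_loss(X, Y, w, b_ctr - eps, b_cvr, alpha)) / (2 * eps)
print(abs(analytic - numeric) < 1e-6)  # True
```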
3.2. Training Loop (N iterations)
For each iteration:
- Using the current w, b_ctr, b_cvr, compute the predictions y_hat_ctr and y_hat_cvr for all samples.
- Compute the CTR and CVR prediction errors error_ctr = y_hat_ctr - y_ctr and error_cvr = y_hat_cvr - y_cvr.
- Compute the gradients grad_w, grad_b_ctr, grad_b_cvr using the formulas derived above.
- Update the parameters by gradient descent (a vectorized sketch follows this list):
  - w = w - lr * grad_w
  - b_ctr = b_ctr - lr * grad_b_ctr
  - b_cvr = b_cvr - lr * grad_b_cvr
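For readers who prefer vectorized code, here is a minimal NumPy sketch of a single update step; NumPy is an assumption of this sketch only, the reference solutions below use plain loops:

```python
import numpy as np

def gd_step(X, Y, w, b_ctr, b_cvr, lr, alpha):
    """One batch-gradient-descent step. Shapes: X (m, d), Y (m, 2), w (d,)."""
    m = X.shape[0]
    pred = X @ w                       # shared linear term for both tasks
    err_ctr = pred + b_ctr - Y[:, 0]   # CTR residuals
    err_cvr = pred + b_cvr - Y[:, 1]   # CVR residuals
    # Gradients from section 3.1 (alpha weights the CVR terms).
    w = w - lr * (2.0 / m) * (X.T @ (err_ctr + alpha * err_cvr))
    b_ctr -= lr * (2.0 / m) * err_ctr.sum()
    b_cvr -= lr * (2.0 * alpha / m) * err_cvr.sum()
    return w, b_ctr, b_cvr
```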
4. Final Loss Computation and Output
- After N iterations (if N = 0, use the initial parameters directly), recompute y_hat_ctr and y_hat_cvr for all samples using the final w, b_ctr, b_cvr.
- Compute MSE_ctr and MSE_cvr.
- Compute the final joint loss Loss = MSE_ctr + alpha * MSE_cvr.
- Multiply Loss by 10^10 and round Half Up to an integer. For exact decimal behavior, BigDecimal is recommended in Java. In Python, the built-in round() rounds halves to even (banker's rounding), so the decimal module with ROUND_HALF_UP should be used instead; see the sketch after this list. In C++, std::round rounds halves away from zero, which coincides with Half Up because Loss is non-negative.
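A quick standard-library illustration of the Half Up pitfall in Python:

```python
from decimal import Decimal, ROUND_HALF_UP

print(round(2.5))  # 2 -- built-in round() is half-to-even
print(int(Decimal("2.5").quantize(Decimal("1"), rounding=ROUND_HALF_UP)))  # 3 -- Half Up
```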
Code
C++:

```cpp
#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <cmath>
#include <iomanip>
using namespace std;
// 解析字符串到二维向量
vector<vector<double>> parse_matrix(const string& s) {
vector<vector<double>> matrix;
stringstream ss(s);
string row_str;
while (getline(ss, row_str, ';')) {
vector<double> row;
stringstream row_ss(row_str);
string val_str;
while (getline(row_ss, val_str, ',')) {
row.push_back(stod(val_str));
}
matrix.push_back(row);
}
return matrix;
}
int main() {
string x_str, y_str;
int n;
double lr, alpha;
cin >> x_str >> y_str >> n >> lr >> alpha;
vector<vector<double>> X = parse_matrix(x_str);
vector<vector<double>> Y = parse_matrix(y_str);
int m = X.size();
int d = X[0].size();
vector<double> w(d, 0.0);
double b_ctr = 0.0, b_cvr = 0.0;
for (int iter = 0; iter < n; ++iter) {
vector<double> y_hat_ctr(m), y_hat_cvr(m);
for (int i = 0; i < m; ++i) {
double dot_product = 0.0;
for (int j = 0; j < d; ++j) {
dot_product += X[i][j] * w[j];
}
y_hat_ctr[i] = dot_product + b_ctr;
y_hat_cvr[i] = dot_product + b_cvr;
}
vector<double> error_ctr(m), error_cvr(m);
for (int i = 0; i < m; ++i) {
error_ctr[i] = y_hat_ctr[i] - Y[i][0];
error_cvr[i] = y_hat_cvr[i] - Y[i][1];
}
vector<double> grad_w(d, 0.0);
double grad_b_ctr = 0.0;
double grad_b_cvr = 0.0;
for (int i = 0; i < m; ++i) {
grad_b_ctr += error_ctr[i];
grad_b_cvr += error_cvr[i];
for (int j = 0; j < d; ++j) {
grad_w[j] += error_ctr[i] * X[i][j] + alpha * error_cvr[i] * X[i][j];
}
}
for (int j = 0; j < d; ++j) {
w[j] -= lr * (2.0 / m) * grad_w[j];
}
b_ctr -= lr * (2.0 / m) * grad_b_ctr;
b_cvr -= lr * (2.0 * alpha / m) * grad_b_cvr;
}
double mse_ctr = 0.0, mse_cvr = 0.0;
for (int i = 0; i < m; ++i) {
double dot_product = 0.0;
for (int j = 0; j < d; ++j) {
dot_product += X[i][j] * w[j];
}
double y_hat_ctr_final = dot_product + b_ctr;
double y_hat_cvr_final = dot_product + b_cvr;
mse_ctr += pow(y_hat_ctr_final - Y[i][0], 2);
mse_cvr += pow(y_hat_cvr_final - Y[i][1], 2);
}
mse_ctr /= m;
mse_cvr /= m;
double loss = mse_ctr + alpha * mse_cvr;
    long long result = llround(loss * 1e10); // rounds halves away from zero: Half Up for non-negative loss
cout << result << endl;
return 0;
}
```
Java:

```java
import java.util.Scanner;
import java.util.ArrayList;
import java.util.List;
import java.math.BigDecimal;
import java.math.RoundingMode;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
String xStr = sc.next();
String yStr = sc.next();
int nIter = sc.nextInt();
double lr = sc.nextDouble();
double alpha = sc.nextDouble();
double[][] X = parseMatrix(xStr);
double[][] Y = parseMatrix(yStr);
int m = X.length;
int d = X[0].length;
double[] w = new double[d]; // Initialized to 0.0
double b_ctr = 0.0;
double b_cvr = 0.0;
for (int iter = 0; iter < nIter; iter++) {
double[] y_hat_ctr = new double[m];
double[] y_hat_cvr = new double[m];
for (int i = 0; i < m; i++) {
double dotProduct = 0;
for (int j = 0; j < d; j++) {
dotProduct += X[i][j] * w[j];
}
y_hat_ctr[i] = dotProduct + b_ctr;
y_hat_cvr[i] = dotProduct + b_cvr;
}
double[] error_ctr = new double[m];
double[] error_cvr = new double[m];
for (int i = 0; i < m; i++) {
error_ctr[i] = y_hat_ctr[i] - Y[i][0];
error_cvr[i] = y_hat_cvr[i] - Y[i][1];
}
double[] grad_w = new double[d];
double grad_b_ctr = 0;
double grad_b_cvr = 0;
for (int i = 0; i < m; i++) {
grad_b_ctr += error_ctr[i];
grad_b_cvr += error_cvr[i];
for (int j = 0; j < d; j++) {
grad_w[j] += error_ctr[i] * X[i][j] + alpha * error_cvr[i] * X[i][j];
}
}
for (int j = 0; j < d; j++) {
w[j] -= lr * (2.0 / m) * grad_w[j];
}
b_ctr -= lr * (2.0 / m) * grad_b_ctr;
b_cvr -= lr * (2.0 * alpha / m) * grad_b_cvr;
}
double mse_ctr = 0;
double mse_cvr = 0;
for (int i = 0; i < m; i++) {
double dotProduct = 0;
for (int j = 0; j < d; j++) {
dotProduct += X[i][j] * w[j];
}
double y_hat_ctr_final = dotProduct + b_ctr;
double y_hat_cvr_final = dotProduct + b_cvr;
mse_ctr += Math.pow(y_hat_ctr_final - Y[i][0], 2);
mse_cvr += Math.pow(y_hat_cvr_final - Y[i][1], 2);
}
mse_ctr /= m;
mse_cvr /= m;
double loss = mse_ctr + alpha * mse_cvr;
BigDecimal lossBd = new BigDecimal(loss);
BigDecimal factor = new BigDecimal("1E10");
long result = lossBd.multiply(factor).setScale(0, RoundingMode.HALF_UP).longValue();
System.out.println(result);
}
private static double[][] parseMatrix(String s) {
String[] rows = s.split(";");
List<double[]> matrixList = new ArrayList<>();
for (String rowStr : rows) {
String[] vals = rowStr.split(",");
double[] row = new double[vals.length];
for (int i = 0; i < vals.length; i++) {
row[i] = Double.parseDouble(vals[i]);
}
matrixList.add(row);
}
return matrixList.toArray(new double[0][]);
}
}
```
Python:

```python
from decimal import Decimal, ROUND_HALF_UP

def solve():
x_str = input()
y_str = input()
n_iter = int(input())
lr = float(input())
alpha = float(input())
X = [[float(v) for v in row.split(',')] for row in x_str.split(';')]
Y = [[float(v) for v in row.split(',')] for row in y_str.split(';')]
m = len(X)
d = len(X[0])
w = [0.0] * d
b_ctr = 0.0
b_cvr = 0.0
for _ in range(n_iter):
y_hat_ctr = [0.0] * m
y_hat_cvr = [0.0] * m
for i in range(m):
dot_product = sum(X[i][j] * w[j] for j in range(d))
y_hat_ctr[i] = dot_product + b_ctr
y_hat_cvr[i] = dot_product + b_cvr
error_ctr = [(y_hat_ctr[i] - Y[i][0]) for i in range(m)]
error_cvr = [(y_hat_cvr[i] - Y[i][1]) for i in range(m)]
grad_w = [0.0] * d
grad_b_ctr = sum(error_ctr)
grad_b_cvr = sum(error_cvr)
for j in range(d):
s = 0
for i in range(m):
s += error_ctr[i] * X[i][j] + alpha * error_cvr[i] * X[i][j]
grad_w[j] = s
for j in range(d):
w[j] -= lr * (2.0 / m) * grad_w[j]
b_ctr -= lr * (2.0 / m) * grad_b_ctr
b_cvr -= lr * (2.0 * alpha / m) * grad_b_cvr
mse_ctr = 0.0
mse_cvr = 0.0
for i in range(m):
dot_product = sum(X[i][j] * w[j] for j in range(d))
y_hat_ctr_final = dot_product + b_ctr
y_hat_cvr_final = dot_product + b_cvr
mse_ctr += (y_hat_ctr_final - Y[i][0]) ** 2
mse_cvr += (y_hat_cvr_final - Y[i][1]) ** 2
mse_ctr /= m
mse_cvr /= m
loss = mse_ctr + alpha * mse_cvr
    # Half Up rounding; the built-in round() is half-to-even and may differ.
    result = int((Decimal(loss) * Decimal(10) ** 10)
                 .quantize(Decimal("1"), rounding=ROUND_HALF_UP))
print(result)
solve()
```
Algorithm and Complexity
- Algorithm: simulation of Batch Gradient Descent
- Time complexity: O(N · m · d), where N is the number of iterations, m is the number of samples, and d is the feature dimension. Each iteration computes predictions and gradients over all samples, dominated by the matrix-vector product at O(m · d), and this is repeated N times.
- Space complexity: O(m · d), mainly for storing the feature matrix X and the label matrix Y.
