REALHW88 Multi-Objective Recommendation Ranking Model Optimization

Problem Link

Multi-Objective Recommendation Ranking Model Optimization

Problem Description

In a dual-objective recommendation ranking scenario, we need to predict both the click-through rate (CTR) and the conversion rate (CVR). A shared linear weight vector w extracts common features, while each task gets its own bias term, b_ctr and b_cvr.

Given a feature matrix X and a label matrix Y (each row of the form [ctr, cvr]), start from all-zero parameters and run batch gradient descent for N iterations with learning rate lr. After training, recompute the joint loss once with the final parameters:

  • Predictions: y_hat_ctr = X·w + b_ctr and y_hat_cvr = X·w + b_cvr
  • MSE_ctr and MSE_cvr are the mean squared errors of the corresponding tasks
  • Joint loss: Loss = MSE_ctr + alpha × MSE_cvr
  • Output: multiply Loss by 10^10 and round it to an integer using Half Up rounding

Solution Approach

This is a numerical problem that requires an exact simulation of batch gradient descent. We must follow the formulas defined in the problem statement step by step: parse the data, train the model, and compute the final loss.

1. Data Parsing

First, write a helper function that parses an input string (such as "a,b;c,d;...") into a 2D matrix of floats: split into rows on ; first, then split each row into elements on ,.
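
As a minimal sketch (mirroring the full solutions below), the whole parse is two nested splits in Python:

def parse_matrix(s):
    # Split rows on ';', then values on ',', converting each token to float.
    return [[float(v) for v in row.split(',')] for row in s.split(';')]

# e.g. parse_matrix("1,2;3,4") == [[1.0, 2.0], [3.0, 4.0]]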

2. Parameter Initialization

  • Let m be the number of samples (rows of X) and d the feature dimension (columns of X).
  • Initialize the weight vector w as a zero vector of length d.
  • Initialize both biases b_ctr and b_cvr to 0.
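
With m and d read off the parsed matrix, the zero initialization is one line each (a trivial sketch matching the solutions below):

m, d = len(X), len(X[0])
w = [0.0] * d            # shared weight vector, all zeros
b_ctr = b_cvr = 0.0      # per-task biases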

3. Batch Gradient Descent

This is the core of the algorithm. We iterate N times, updating the parameters in each iteration.

3.1. Gradient Derivation

The joint loss function is:

Loss = MSE_ctr + alpha × MSE_cvr
     = (1/m) × Σ_i (y_hat_ctr_i − y_ctr_i)² + (alpha/m) × Σ_i (y_hat_cvr_i − y_cvr_i)²

We need the partial derivatives (gradients) of Loss with respect to w, b_ctr, and b_cvr:

  • Gradient for b_ctr: ∂Loss/∂b_ctr = (2/m) × Σ_i (y_hat_ctr_i − y_ctr_i)
  • Gradient for b_cvr: ∂Loss/∂b_cvr = (2 × alpha/m) × Σ_i (y_hat_cvr_i − y_cvr_i)
  • Gradient for w (a vector): ∂Loss/∂w = (2/m) × Σ_i [(y_hat_ctr_i − y_ctr_i) + alpha × (y_hat_cvr_i − y_cvr_i)] × x_i

These formulas can be sanity-checked numerically, as shown in the sketch after this list.
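
The following sketch verifies the b_ctr formula against a central finite difference; the toy values of X, Y, and alpha are made up purely for illustration:

# Numerical sanity check of the b_ctr gradient (toy values, for illustration only).
X = [[1.0, 2.0], [3.0, 4.0]]
Y = [[0.1, 0.05], [0.2, 0.1]]
w, b_cvr, alpha = [0.0, 0.0], 0.0, 0.5
m = len(X)

def joint_loss(b_ctr):
    # Loss = MSE_ctr + alpha * MSE_cvr, exactly as defined above.
    preds = [sum(xj * wj for xj, wj in zip(x, w)) for x in X]
    mse_ctr = sum((p + b_ctr - y[0]) ** 2 for p, y in zip(preds, Y)) / m
    mse_cvr = sum((p + b_cvr - y[1]) ** 2 for p, y in zip(preds, Y)) / m
    return mse_ctr + alpha * mse_cvr

eps = 1e-6
numeric = (joint_loss(eps) - joint_loss(-eps)) / (2 * eps)  # central difference at b_ctr = 0
analytic = (2.0 / m) * sum(0.0 - y[0] for y in Y)           # (2/m) * sum(error_ctr) at all-zero parameters
print(abs(numeric - analytic) < 1e-8)  # expected: True
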
3.2. Training Loop (N Iterations)

In each iteration:

  1. Using the current w, b_ctr, b_cvr, compute the predictions y_hat_ctr and y_hat_cvr for all samples.
  2. Compute the prediction errors error_ctr = y_hat_ctr - y_ctr and error_cvr = y_hat_cvr - y_cvr.
  3. Compute the gradients grad_w, grad_b_ctr, grad_b_cvr using the formulas derived above.
  4. Apply the gradient descent update (see the vectorized sketch after this list):
    • w = w - lr * grad_w
    • b_ctr = b_ctr - lr * grad_b_ctr
    • b_cvr = b_cvr - lr * grad_b_cvr
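
For reference, one iteration of this loop can also be written in vectorized form; this is a sketch assuming X (m×d) and Y (m×2) are NumPy arrays, whereas the reference solutions below deliberately use plain loops:

import numpy as np

def gd_step(X, Y, w, b_ctr, b_cvr, lr, alpha):
    # One batch gradient descent step over all m samples.
    m = X.shape[0]
    err_ctr = X @ w + b_ctr - Y[:, 0]
    err_cvr = X @ w + b_cvr - Y[:, 1]
    grad_w = (2.0 / m) * (X.T @ (err_ctr + alpha * err_cvr))
    grad_b_ctr = (2.0 / m) * err_ctr.sum()
    grad_b_cvr = (2.0 * alpha / m) * err_cvr.sum()
    return w - lr * grad_w, b_ctr - lr * grad_b_ctr, b_cvr - lr * grad_b_cvr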

4. Final Loss Computation and Output

  • After N iterations (if N=0, use the initial parameters directly), recompute y_hat_ctr and y_hat_cvr for all samples with the final w, b_ctr, b_cvr.
  • Compute MSE_ctr and MSE_cvr.
  • Compute the final joint loss Loss = MSE_ctr + alpha * MSE_cvr.
  • Multiply Loss by 10^10, round Half Up, and print the result. For precision, BigDecimal is recommended for this step in Java. Note that Python's built-in round() rounds half to even (banker's rounding), so the decimal module should be used there instead, as shown after this list.

Code

C++

#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <cmath>
#include <iomanip>

using namespace std;

// Parse the input string into a 2D vector of doubles
vector<vector<double>> parse_matrix(const string& s) {
    vector<vector<double>> matrix;
    stringstream ss(s);
    string row_str;
    while (getline(ss, row_str, ';')) {
        vector<double> row;
        stringstream row_ss(row_str);
        string val_str;
        while (getline(row_ss, val_str, ',')) {
            row.push_back(stod(val_str));
        }
        matrix.push_back(row);
    }
    return matrix;
}

int main() {
    string x_str, y_str;
    int n;
    double lr, alpha;

    cin >> x_str >> y_str >> n >> lr >> alpha;

    vector<vector<double>> X = parse_matrix(x_str);
    vector<vector<double>> Y = parse_matrix(y_str);

    int m = X.size();
    int d = X[0].size();

    vector<double> w(d, 0.0);
    double b_ctr = 0.0, b_cvr = 0.0;

    for (int iter = 0; iter < n; ++iter) {
        vector<double> y_hat_ctr(m), y_hat_cvr(m);
        for (int i = 0; i < m; ++i) {
            double dot_product = 0.0;
            for (int j = 0; j < d; ++j) {
                dot_product += X[i][j] * w[j];
            }
            y_hat_ctr[i] = dot_product + b_ctr;
            y_hat_cvr[i] = dot_product + b_cvr;
        }

        vector<double> error_ctr(m), error_cvr(m);
        for (int i = 0; i < m; ++i) {
            error_ctr[i] = y_hat_ctr[i] - Y[i][0];
            error_cvr[i] = y_hat_cvr[i] - Y[i][1];
        }

        vector<double> grad_w(d, 0.0);
        double grad_b_ctr = 0.0;
        double grad_b_cvr = 0.0;

        for (int i = 0; i < m; ++i) {
            grad_b_ctr += error_ctr[i];
            grad_b_cvr += error_cvr[i];
            for (int j = 0; j < d; ++j) {
                grad_w[j] += error_ctr[i] * X[i][j] + alpha * error_cvr[i] * X[i][j];
            }
        }

        for (int j = 0; j < d; ++j) {
            w[j] -= lr * (2.0 / m) * grad_w[j];
        }
        b_ctr -= lr * (2.0 / m) * grad_b_ctr;
        b_cvr -= lr * (2.0 * alpha / m) * grad_b_cvr;
    }
    
    double mse_ctr = 0.0, mse_cvr = 0.0;
    for (int i = 0; i < m; ++i) {
        double dot_product = 0.0;
        for (int j = 0; j < d; ++j) {
            dot_product += X[i][j] * w[j];
        }
        double y_hat_ctr_final = dot_product + b_ctr;
        double y_hat_cvr_final = dot_product + b_cvr;
        mse_ctr += pow(y_hat_ctr_final - Y[i][0], 2);
        mse_cvr += pow(y_hat_cvr_final - Y[i][1], 2);
    }
    mse_ctr /= m;
    mse_cvr /= m;

    double loss = mse_ctr + alpha * mse_cvr;
    // loss >= 0, so llround's half-away-from-zero matches Half Up here
    long long result = llround(loss * 1e10);

    cout << result << endl;

    return 0;
}

Java

import java.util.Scanner;
import java.util.ArrayList;
import java.util.List;
import java.math.BigDecimal;
import java.math.RoundingMode;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        String xStr = sc.next();
        String yStr = sc.next();
        int nIter = sc.nextInt();
        double lr = sc.nextDouble();
        double alpha = sc.nextDouble();

        double[][] X = parseMatrix(xStr);
        double[][] Y = parseMatrix(yStr);

        int m = X.length;
        int d = X[0].length;

        double[] w = new double[d]; // Initialized to 0.0
        double b_ctr = 0.0;
        double b_cvr = 0.0;

        for (int iter = 0; iter < nIter; iter++) {
            double[] y_hat_ctr = new double[m];
            double[] y_hat_cvr = new double[m];
            
            for (int i = 0; i < m; i++) {
                double dotProduct = 0;
                for (int j = 0; j < d; j++) {
                    dotProduct += X[i][j] * w[j];
                }
                y_hat_ctr[i] = dotProduct + b_ctr;
                y_hat_cvr[i] = dotProduct + b_cvr;
            }

            double[] error_ctr = new double[m];
            double[] error_cvr = new double[m];
            for (int i = 0; i < m; i++) {
                error_ctr[i] = y_hat_ctr[i] - Y[i][0];
                error_cvr[i] = y_hat_cvr[i] - Y[i][1];
            }

            double[] grad_w = new double[d];
            double grad_b_ctr = 0;
            double grad_b_cvr = 0;

            for (int i = 0; i < m; i++) {
                grad_b_ctr += error_ctr[i];
                grad_b_cvr += error_cvr[i];
                for (int j = 0; j < d; j++) {
                    grad_w[j] += error_ctr[i] * X[i][j] + alpha * error_cvr[i] * X[i][j];
                }
            }

            for (int j = 0; j < d; j++) {
                w[j] -= lr * (2.0 / m) * grad_w[j];
            }
            b_ctr -= lr * (2.0 / m) * grad_b_ctr;
            b_cvr -= lr * (2.0 * alpha / m) * grad_b_cvr;
        }

        double mse_ctr = 0;
        double mse_cvr = 0;
        for (int i = 0; i < m; i++) {
            double dotProduct = 0;
            for (int j = 0; j < d; j++) {
                dotProduct += X[i][j] * w[j];
            }
            double y_hat_ctr_final = dotProduct + b_ctr;
            double y_hat_cvr_final = dotProduct + b_cvr;
            mse_ctr += Math.pow(y_hat_ctr_final - Y[i][0], 2);
            mse_cvr += Math.pow(y_hat_cvr_final - Y[i][1], 2);
        }
        mse_ctr /= m;
        mse_cvr /= m;

        double loss = mse_ctr + alpha * mse_cvr;

        BigDecimal lossBd = new BigDecimal(loss);
        BigDecimal factor = new BigDecimal("1E10");
        long result = lossBd.multiply(factor).setScale(0, RoundingMode.HALF_UP).longValue();
        
        System.out.println(result);
    }

    private static double[][] parseMatrix(String s) {
        String[] rows = s.split(";");
        List<double[]> matrixList = new ArrayList<>();
        for (String rowStr : rows) {
            String[] vals = rowStr.split(",");
            double[] row = new double[vals.length];
            for (int i = 0; i < vals.length; i++) {
                row[i] = Double.parseDouble(vals[i]);
            }
            matrixList.add(row);
        }
        return matrixList.toArray(new double[0][]);
    }
}

Python

from decimal import Decimal, ROUND_HALF_UP

def solve():
    x_str = input()
    y_str = input()
    n_iter = int(input())
    lr = float(input())
    alpha = float(input())

    X = [[float(v) for v in row.split(',')] for row in x_str.split(';')]
    Y = [[float(v) for v in row.split(',')] for row in y_str.split(';')]

    m = len(X)
    d = len(X[0])

    w = [0.0] * d
    b_ctr = 0.0
    b_cvr = 0.0

    for _ in range(n_iter):
        y_hat_ctr = [0.0] * m
        y_hat_cvr = [0.0] * m

        for i in range(m):
            dot_product = sum(X[i][j] * w[j] for j in range(d))
            y_hat_ctr[i] = dot_product + b_ctr
            y_hat_cvr[i] = dot_product + b_cvr

        error_ctr = [(y_hat_ctr[i] - Y[i][0]) for i in range(m)]
        error_cvr = [(y_hat_cvr[i] - Y[i][1]) for i in range(m)]

        grad_w = [0.0] * d
        grad_b_ctr = sum(error_ctr)
        grad_b_cvr = sum(error_cvr)

        for j in range(d):
            s = 0
            for i in range(m):
                s += error_ctr[i] * X[i][j] + alpha * error_cvr[i] * X[i][j]
            grad_w[j] = s
        
        for j in range(d):
            w[j] -= lr * (2.0 / m) * grad_w[j]
        
        b_ctr -= lr * (2.0 / m) * grad_b_ctr
        b_cvr -= lr * (2.0 * alpha / m) * grad_b_cvr

    mse_ctr = 0.0
    mse_cvr = 0.0
    for i in range(m):
        dot_product = sum(X[i][j] * w[j] for j in range(d))
        y_hat_ctr_final = dot_product + b_ctr
        y_hat_cvr_final = dot_product + b_cvr
        mse_ctr += (y_hat_ctr_final - Y[i][0]) ** 2
        mse_cvr += (y_hat_cvr_final - Y[i][1]) ** 2
    
    mse_ctr /= m
    mse_cvr /= m

    loss = mse_ctr + alpha * mse_cvr
    # Half Up rounding via decimal; Python's built-in round() rounds half to even.
    result = int((Decimal(loss) * Decimal(10) ** 10).quantize(Decimal('1'), rounding=ROUND_HALF_UP))
    
    print(result)

solve()

Algorithm and Complexity

  • Algorithm: simulation of batch gradient descent
  • Time complexity: O(N × m × d), where N is the number of iterations, m the number of samples, and d the feature dimension. Each iteration computes predictions and gradients over all samples; the dominant cost is the O(m × d) matrix-vector product, repeated N times.
  • Space complexity: O(m × d), mainly for storing the feature matrix X and the label matrix Y.