题目链接

三项评分线性定价

题目描述

给定 组已上市手机的三项评分(硬件能力 、系统流畅度 、智能能力 )与售价 。请使用线性模型 来估计参数

具体要求:

  1. 使用最小二乘法的闭式解(正规方程)来拟合参数。
  2. 给定 个新机型的评分,预测其售价并输出四舍五入后的整数。
  3. 输入保证数值稳定且存在唯一解。

输入:

  • 第 1 行:一个整数 ,表示已知样本数量。
  • 第 2 行:共 个整数,按样本顺序依次给出:
  • 第 3 行:一个整数 ,表示预测机型数量。
  • 第 4 行:共 个整数,按机型顺序依次给出:

输出:

  • 个整数,代表预测机型价格,空格分隔。

解题思路

本题是一个典型的多元线性回归问题。我们需要通过最小二乘法求解正规方程。

  1. 构建正规方程: 设设计矩阵为 ,其中每一行对应一个样本 。 设目标向量为 ,包含 个样本的价格。 线性模型可以表示为 ,其中 。 根据最小二乘法,参数的解析解为:

  2. 具体计算步骤

    • ,这是一个 的矩阵。
    • ,这是一个 的向量。
    • 求解线性方程组 。由于矩阵规模很小(仅为 ),可以使用高斯消元法
  3. 预测与输出

    • 得到权重向量 后,对于每个测试样本 ,计算预测价格
    • 对结果进行四舍五入(通常可以使用 函数或 )。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>

using namespace std;

// 高斯消元法求解 4x4 方程组 Aw = b
void solve(double A[4][4], double b[4], double w[4]) {
    int n = 4;
    for (int i = 0; i < n; i++) {
        // 寻找主元
        int pivot = i;
        for (int j = i + 1; j < n; j++) {
            if (abs(A[j][i]) > abs(A[pivot][i])) pivot = j;
        }
        for (int j = i; j < n; j++) swap(A[i][j], A[pivot][j]);
        swap(b[i], b[pivot]);

        // 消元
        for (int j = i + 1; j < n; j++) {
            double factor = A[j][i] / A[i][i];
            b[j] -= factor * b[i];
            for (int k = i; k < n; k++) {
                A[j][k] -= factor * A[i][k];
            }
        }
    }

    // 回代
    for (int i = n - 1; i >= 0; i--) {
        double sum = 0;
        for (int j = i + 1; j < n; j++) {
            sum += A[i][j] * w[j];
        }
        w[i] = (b[i] - sum) / A[i][i];
    }
}

int main() {
    int k;
    cin >> k;
    double A[4][4] = {0};
    double b[4] = {0};
    
    for (int i = 0; i < k; i++) {
        double x[4];
        x[0] = 1.0;
        cin >> x[1] >> x[2] >> x[3];
        double price;
        cin >> price;
        // 累加计算 X^T * X 和 X^T * y
        for (int r = 0; r < 4; r++) {
            for (int c = 0; c < 4; c++) {
                A[r][c] += x[r] * x[c];
            }
            b[r] += x[r] * price;
        }
    }

    double w[4];
    solve(A, b, w);

    int n;
    cin >> n;
    for (int i = 0; i < n; i++) {
        double x1, x2, x3;
        cin >> x1 >> x2 >> x3;
        double res = w[0] + w[1] * x1 + w[2] * x2 + w[3] * x3;
        cout << (long long)round(res) << (i == n - 1 ? "" : " ");
    }
    cout << "\n";

    return 0;
}
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int k = sc.nextInt();
        double[][] A = new double[4][4];
        double[] b = new double[4];

        for (int i = 0; i < k; i++) {
            double[] x = new double[4];
            x[0] = 1.0;
            x[1] = sc.nextDouble();
            x[2] = sc.nextDouble();
            x[3] = sc.nextDouble();
            double price = sc.nextDouble();

            for (int r = 0; r < 4; r++) {
                for (int c = 0; c < 4; c++) {
                    A[r][c] += x[r] * x[c];
                }
                b[r] += x[r] * price;
            }
        }

        double[] w = solve(A, b);

        int n = sc.nextInt();
        for (int i = 0; i < n; i++) {
            double x1 = sc.nextDouble();
            double x2 = sc.nextDouble();
            double x3 = sc.nextDouble();
            double res = w[0] + w[1] * x1 + w[2] * x2 + w[3] * x3;
            System.out.print(Math.round(res) + (i == n - 1 ? "" : " "));
        }
        System.out.println();
    }

    private static double[] solve(double[][] A, double[] b) {
        int n = 4;
        for (int i = 0; i < n; i++) {
            int pivot = i;
            for (int j = i + 1; j < n; j++) {
                if (Math.abs(A[j][i]) > Math.abs(A[pivot][i])) pivot = j;
            }
            double[] tempRow = A[i];
            A[i] = A[pivot];
            A[pivot] = tempRow;
            double tempB = b[i];
            b[i] = b[pivot];
            b[pivot] = tempB;

            for (int j = i + 1; j < n; j++) {
                double factor = A[j][i] / A[i][i];
                b[j] -= factor * b[i];
                for (int m = i; m < n; m++) {
                    A[j][m] -= factor * A[i][m];
                }
            }
        }

        double[] w = new double[n];
        for (int i = n - 1; i >= 0; i--) {
            double sum = 0;
            for (int j = i + 1; j < n; j++) {
                sum += A[i][j] * w[j];
            }
            w[i] = (b[i] - sum) / A[i][i];
        }
        return w;
    }
}
import math

def solve_linear_system(A, b):
    n = 4
    for i in range(n):
        pivot = i
        for j in range(i + 1, n):
            if abs(A[j][i]) > abs(A[pivot][i]):
                pivot = j
        A[i], A[pivot] = A[pivot], A[i]
        b[i], b[pivot] = b[pivot], b[i]

        for j in range(i + 1, n):
            factor = A[j][i] / A[i][i]
            b[j] -= factor * b[i]
            for k in range(i, n):
                A[j][k] -= factor * A[i][k]

    w = [0] * n
    for i in range(n - 1, -1, -1):
        s = sum(A[i][j] * w[j] for j in range(i + 1, n))
        w[i] = (b[i] - s) / A[i][i]
    return w

def solve():
    k = int(input())
    data = []
    # 连续读取所有样本数据
    while len(data) < 4 * k:
        data.extend(map(float, input().split()))
    
    A = [[0.0] * 4 for _ in range(4)]
    b = [0.0] * 4
    
    for i in range(k):
        x = [1.0, data[4*i], data[4*i+1], data[4*i+2]]
        price = data[4*i+3]
        for r in range(4):
            for c in range(4):
                A[r][c] += x[r] * x[c]
            b[r] += x[r] * price
            
    w = solve_linear_system(A, b)
    
    n = int(input())
    test_data = []
    while len(test_data) < 3 * n:
        test_data.extend(map(float, input().split()))
        
    preds = []
    for i in range(n):
        x1, x2, x3 = test_data[3*i], test_data[3*i+1], test_data[3*i+2]
        res = w[0] + w[1] * x1 + w[2] * x2 + w[3] * x3
        preds.append(str(int(res + 0.5)))
        
    print(" ".join(preds))

if __name__ == "__main__":
    solve()

算法及复杂度

  • 算法:最小二乘法(正规方程) + 高斯消元法。
  • 时间复杂度:,其中 为特征维度(本题 )。构建 矩阵耗时 ,高斯消元耗时 ,预测耗时
  • 空间复杂度:。需要存储 的矩阵及相关向量。