题目链接
题目描述
给定 组已上市手机的三项评分(硬件能力
、系统流畅度
、智能能力
)与售价
。请使用线性模型
来估计参数
。
具体要求:
- 使用最小二乘法的闭式解(正规方程)来拟合参数。
- 给定
个新机型的评分,预测其售价并输出四舍五入后的整数。
- 输入保证数值稳定且存在唯一解。
输入:
- 第 1 行:一个整数
,表示已知样本数量。
- 第 2 行:共
个整数,按样本顺序依次给出:
。
- 第 3 行:一个整数
,表示预测机型数量。
- 第 4 行:共
个整数,按机型顺序依次给出:
。
输出:
个整数,代表预测机型价格,空格分隔。
解题思路
本题是一个典型的多元线性回归问题。我们需要通过最小二乘法求解正规方程。
-
构建正规方程: 设设计矩阵为
,其中每一行对应一个样本
。 设目标向量为
,包含
个样本的价格。 线性模型可以表示为
,其中
。 根据最小二乘法,参数的解析解为:
-
具体计算步骤:
- 令
,这是一个
的矩阵。
- 令
,这是一个
的向量。
- 求解线性方程组
。由于矩阵规模很小(仅为
),可以使用高斯消元法。
- 令
-
预测与输出:
- 得到权重向量
后,对于每个测试样本
,计算预测价格
。
- 对结果进行四舍五入(通常可以使用
函数或
)。
- 得到权重向量
代码
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;
// 高斯消元法求解 4x4 方程组 Aw = b
void solve(double A[4][4], double b[4], double w[4]) {
int n = 4;
for (int i = 0; i < n; i++) {
// 寻找主元
int pivot = i;
for (int j = i + 1; j < n; j++) {
if (abs(A[j][i]) > abs(A[pivot][i])) pivot = j;
}
for (int j = i; j < n; j++) swap(A[i][j], A[pivot][j]);
swap(b[i], b[pivot]);
// 消元
for (int j = i + 1; j < n; j++) {
double factor = A[j][i] / A[i][i];
b[j] -= factor * b[i];
for (int k = i; k < n; k++) {
A[j][k] -= factor * A[i][k];
}
}
}
// 回代
for (int i = n - 1; i >= 0; i--) {
double sum = 0;
for (int j = i + 1; j < n; j++) {
sum += A[i][j] * w[j];
}
w[i] = (b[i] - sum) / A[i][i];
}
}
int main() {
int k;
cin >> k;
double A[4][4] = {0};
double b[4] = {0};
for (int i = 0; i < k; i++) {
double x[4];
x[0] = 1.0;
cin >> x[1] >> x[2] >> x[3];
double price;
cin >> price;
// 累加计算 X^T * X 和 X^T * y
for (int r = 0; r < 4; r++) {
for (int c = 0; c < 4; c++) {
A[r][c] += x[r] * x[c];
}
b[r] += x[r] * price;
}
}
double w[4];
solve(A, b, w);
int n;
cin >> n;
for (int i = 0; i < n; i++) {
double x1, x2, x3;
cin >> x1 >> x2 >> x3;
double res = w[0] + w[1] * x1 + w[2] * x2 + w[3] * x3;
cout << (long long)round(res) << (i == n - 1 ? "" : " ");
}
cout << "\n";
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int k = sc.nextInt();
double[][] A = new double[4][4];
double[] b = new double[4];
for (int i = 0; i < k; i++) {
double[] x = new double[4];
x[0] = 1.0;
x[1] = sc.nextDouble();
x[2] = sc.nextDouble();
x[3] = sc.nextDouble();
double price = sc.nextDouble();
for (int r = 0; r < 4; r++) {
for (int c = 0; c < 4; c++) {
A[r][c] += x[r] * x[c];
}
b[r] += x[r] * price;
}
}
double[] w = solve(A, b);
int n = sc.nextInt();
for (int i = 0; i < n; i++) {
double x1 = sc.nextDouble();
double x2 = sc.nextDouble();
double x3 = sc.nextDouble();
double res = w[0] + w[1] * x1 + w[2] * x2 + w[3] * x3;
System.out.print(Math.round(res) + (i == n - 1 ? "" : " "));
}
System.out.println();
}
private static double[] solve(double[][] A, double[] b) {
int n = 4;
for (int i = 0; i < n; i++) {
int pivot = i;
for (int j = i + 1; j < n; j++) {
if (Math.abs(A[j][i]) > Math.abs(A[pivot][i])) pivot = j;
}
double[] tempRow = A[i];
A[i] = A[pivot];
A[pivot] = tempRow;
double tempB = b[i];
b[i] = b[pivot];
b[pivot] = tempB;
for (int j = i + 1; j < n; j++) {
double factor = A[j][i] / A[i][i];
b[j] -= factor * b[i];
for (int m = i; m < n; m++) {
A[j][m] -= factor * A[i][m];
}
}
}
double[] w = new double[n];
for (int i = n - 1; i >= 0; i--) {
double sum = 0;
for (int j = i + 1; j < n; j++) {
sum += A[i][j] * w[j];
}
w[i] = (b[i] - sum) / A[i][i];
}
return w;
}
}
import math
def solve_linear_system(A, b):
n = 4
for i in range(n):
pivot = i
for j in range(i + 1, n):
if abs(A[j][i]) > abs(A[pivot][i]):
pivot = j
A[i], A[pivot] = A[pivot], A[i]
b[i], b[pivot] = b[pivot], b[i]
for j in range(i + 1, n):
factor = A[j][i] / A[i][i]
b[j] -= factor * b[i]
for k in range(i, n):
A[j][k] -= factor * A[i][k]
w = [0] * n
for i in range(n - 1, -1, -1):
s = sum(A[i][j] * w[j] for j in range(i + 1, n))
w[i] = (b[i] - s) / A[i][i]
return w
def solve():
k = int(input())
data = []
# 连续读取所有样本数据
while len(data) < 4 * k:
data.extend(map(float, input().split()))
A = [[0.0] * 4 for _ in range(4)]
b = [0.0] * 4
for i in range(k):
x = [1.0, data[4*i], data[4*i+1], data[4*i+2]]
price = data[4*i+3]
for r in range(4):
for c in range(4):
A[r][c] += x[r] * x[c]
b[r] += x[r] * price
w = solve_linear_system(A, b)
n = int(input())
test_data = []
while len(test_data) < 3 * n:
test_data.extend(map(float, input().split()))
preds = []
for i in range(n):
x1, x2, x3 = test_data[3*i], test_data[3*i+1], test_data[3*i+2]
res = w[0] + w[1] * x1 + w[2] * x2 + w[3] * x3
preds.append(str(int(res + 0.5)))
print(" ".join(preds))
if __name__ == "__main__":
solve()
算法及复杂度
- 算法:最小二乘法(正规方程) + 高斯消元法。
- 时间复杂度:
,其中
为特征维度(本题
)。构建
矩阵耗时
,高斯消元耗时
,预测耗时
。
- 空间复杂度:
。需要存储
的矩阵及相关向量。

京公网安备 11010502036488号