多项式岭回归
题目分析
智慧城市供水监测系统有 天的连续用水量数据,其中
条记录缺失(标记为 Gap)。对每个缺失位置,用其前后最近的连续真实数据段拟合二阶多项式岭回归模型来预测。
思路
逐 Gap 独立建模预测
对每个缺失位置 ,确定训练数据的范围:
- 前向区间
:从
向前搜索,遇到第一个 Gap 时停止,
为该 Gap 的下一天;若无 Gap 则
- 后向区间
:从
向后搜索,遇到第一个 Gap 时停止,
为该 Gap 的前一天;若无 Gap 则
合并两个区间的真实数据作为训练集,用天数 (绝对天号)构造设计矩阵
(每行为
),通过岭回归公式求解系数:
$$
其中 ,
为
单位矩阵。最后用
在
处预测。
实现细节: 只有
,直接用高斯消元法求解线性方程组即可,无需引入矩阵库。
复杂度
- 时间复杂度:
,每个 Gap 需遍历前后区间收集数据
- 空间复杂度:
,存储数据数组
代码
import sys
def main():
lines = sys.stdin.read().strip().split('\n')
first = lines[0].split()
M = int(first[0])
N = int(first[1])
values = [None] * (N + 1)
gap_positions = []
gap_set = set()
for i in range(1, N + 1):
s = lines[i].strip()
if s.startswith("Gap_"):
gap_positions.append((s, i))
gap_set.add(i)
else:
values[i] = float(s)
lam = 0.1
results = []
for gap_name, pos in gap_positions:
# 前向区间
left_start = 1
for d in range(pos - 1, 0, -1):
if d in gap_set:
left_start = d + 1
break
# 后向区间
right_end = N
for d in range(pos + 1, N + 1):
if d in gap_set:
right_end = d - 1
break
# 收集训练数据
days, vals = [], []
for d in range(left_start, pos):
if values[d] is not None:
days.append(d)
vals.append(values[d])
for d in range(pos + 1, right_end + 1):
if values[d] is not None:
days.append(d)
vals.append(values[d])
n = len(days)
# 构造 X^T X + λI 和 X^T y
XtX = [[0.0] * 3 for _ in range(3)]
Xty = [0.0] * 3
for idx in range(n):
x = days[idx]
y = vals[idx]
row = [x * x, x, 1.0]
for a in range(3):
for b in range(3):
XtX[a][b] += row[a] * row[b]
Xty[a] += row[a] * y
for i in range(3):
XtX[i][i] += lam
# 高斯消元求解
A = [XtX[i][:] + [Xty[i]] for i in range(3)]
for col in range(3):
max_row = col
for row in range(col + 1, 3):
if abs(A[row][col]) > abs(A[max_row][col]):
max_row = row
A[col], A[max_row] = A[max_row], A[col]
pivot = A[col][col]
for row in range(col + 1, 3):
factor = A[row][col] / pivot
for j in range(col, 4):
A[row][j] -= factor * A[col][j]
beta = [0.0] * 3
for i in range(2, -1, -1):
beta[i] = A[i][3]
for j in range(i + 1, 3):
beta[i] -= A[i][j] * beta[j]
beta[i] /= A[i][i]
pred = beta[0] * pos * pos + beta[1] * pos + beta[2]
results.append(f"{gap_name}: {pred:.2f}")
print('\n'.join(results))
main()

京公网安备 11010502036488号