二分类逻辑回归
题目分析
本题要求从零实现一个二分类逻辑回归模型,包含以下要素:
- Sigmoid 激活函数:
- 带 L2 正则化的平均交叉熵损失
- 批量梯度下降优化
- 收敛判定:达到最大迭代次数,或相邻两次损失变化小于阈值
思路讲解
模型定义
逻辑回归的预测公式为:
$$
$$
损失函数
平均交叉熵 + L2 正则化:
$$
梯度更新
对每个权重 和偏置
:
$$
$$
每轮迭代后更新参数:,
。
实现要点
- 初始化:所有权重
和偏置
均初始化为 0。
- Sigmoid 数值稳定性:当
时使用
避免溢出。
- 收敛判定:先计算初始损失,每轮更新后计算新损失,若
则停止迭代。
- 预测:概率
预测为 1,否则为 0;概率保留 4 位小数输出。
代码实现
import sys
import math
def sigmoid(z):
if z >= 0:
return 1.0 / (1.0 + math.exp(-z))
else:
ez = math.exp(z)
return ez / (1.0 + ez)
def main():
input_data = sys.stdin.read().split()
idx = 0
n = int(input_data[idx]); idx += 1
max_iter = int(input_data[idx]); idx += 1
alpha = float(input_data[idx]); idx += 1
lam = float(input_data[idx]); idx += 1
tol = float(input_data[idx]); idx += 1
X = []
y = []
for i in range(n):
a = float(input_data[idx]); idx += 1
inc = float(input_data[idx]); idx += 1
dur = float(input_data[idx]); idx += 1
label = int(input_data[idx]); idx += 1
X.append([a, inc, dur])
y.append(label)
m = int(input_data[idx]); idx += 1
test = []
for i in range(m):
a = float(input_data[idx]); idx += 1
inc = float(input_data[idx]); idx += 1
dur = float(input_data[idx]); idx += 1
test.append([a, inc, dur])
w = [0.0, 0.0, 0.0]
b = 0.0
def compute_loss():
total = 0.0
for i in range(n):
z = sum(w[j] * X[i][j] for j in range(3)) + b
p = sigmoid(z)
p = max(min(p, 1 - 1e-15), 1e-15)
total += -(y[i] * math.log(p) + (1 - y[i]) * math.log(1 - p))
avg_loss = total / n
reg = (lam / 2.0) * sum(wj * wj for wj in w)
return avg_loss + reg
prev_loss = compute_loss() if max_iter > 0 else 0.0
for it in range(max_iter):
dw = [0.0, 0.0, 0.0]
db = 0.0
for i in range(n):
z = sum(w[j] * X[i][j] for j in range(3)) + b
p = sigmoid(z)
err = p - y[i]
for j in range(3):
dw[j] += err * X[i][j]
db += err
for j in range(3):
dw[j] = dw[j] / n + lam * w[j]
db = db / n
for j in range(3):
w[j] -= alpha * dw[j]
b -= alpha * db
cur_loss = compute_loss()
if abs(cur_loss - prev_loss) < tol:
break
prev_loss = cur_loss
results = []
for t in test:
z = sum(w[j] * t[j] for j in range(3)) + b
p = sigmoid(z)
pred = 1 if p >= 0.5 else 0
results.append(f"{pred} {p:.4f}")
print('\n'.join(results))
main()
复杂度分析
- 时间复杂度:
,其中
为实际迭代次数,
为训练样本数,
为特征维度。
- 空间复杂度:
,存储训练数据。

京公网安备 11010502036488号