题目链接
题目描述
本题要求实现一个设备故障预测程序。你需要基于给定的训练数据集(包含设备ID、写入次数、读取次数、平均读写延迟、使用年限和设备状态),通过数据清洗和处理,学习一个逻辑回归模型。然后,使用这个模型对新的待预测设备数据进行故障预测(判断为正常或故障)。
核心任务包括:
- 数据清洗:
- 缺失值填充:对数值字段中的
NaN
字符串,使用训练集中该字段有效数值的均值进行填充。 - 异常值处理:对超出指定范围(如写入次数小于0)的数值,使用训练集中该字段有效数值的中位数进行替换。
- 标签缺失处理:丢弃状态字段缺失或无效的训练样本。
- 缺失值填充:对数值字段中的
- 模型训练:
- 使用二分类逻辑回归模型,包含偏置项
。
- 采用批量梯度下降法(Batch GD)进行训练,学习率为
,迭代
次,初始权重全部为
。
- 使用二分类逻辑回归模型,包含偏置项
- 预测:
- 根据学习到的权重,计算待预测样本发生故障的概率
。
- 若概率大于等于
,则判定为故障(输出
1
),否则判定为正常(输出0
)。
- 根据学习到的权重,计算待预测样本发生故障的概率
解题思路
这是一个典型的机器学习流程模拟题,需要严格按照题目要求实现数据处理、模型训练和预测三个步骤。
-
数据读取与初步筛选
- 首先,读取训练样本数量
和所有训练数据。
- 对每一行训练数据进行解析。由于数据可能包含
NaN
字符串,我们需要将特征暂时存为字符串或能够表示缺失值的形式。 - 筛选掉标签(
status
)字段缺失或无法解析为或
的样本。这些样本不参与后续任何计算。
- 首先,读取训练样本数量
-
计算清洗参数
- 遍历筛选后的有效训练样本,为
个数值特征(
writes
,reads
,avg_write_ms
,avg_read_ms
,years
)分别建立一个列表,用于存放“有效数值”。 - 一个数值是“有效的”,当且仅当它不是
NaN
且不在题目定义的异常值范围内。 - 对每个特征的有效数值列表,计算其均值和中位数。
- 均值:用于后续填充
NaN
。 - 中位数:用于后续修正异常值。
- 均值:用于后续填充
- 如果某个特征在训练集中没有任何有效数值,则其均值和中位数都按
处理。
- 遍历筛选后的有效训练样本,为
-
数据清洗
- 使用上一步计算出的均值和中位数,对全部训练数据和测试数据进行清洗。
- 遍历所有样本(包括训练和测试)的每一个特征:
- 如果值为
NaN
,则替换为对应特征的均值。 - 如果值为异常值,则替换为对应特征的中位数。
- 如果值为
- 经过此步骤后,所有数据都变为干净的数值型数据,可以用于模型训练。
-
模型训练(批量梯度下降)
- 逻辑回归模型的预测公式为
,其中
。
- 初始化权重向量
(包含
到
共
个权重)为全
。
- 执行
次迭代:
- 在每次迭代开始时,初始化一个梯度向量
grad
为全。
- 遍历所有
个清洗后的训练样本:
- 计算当前样本的预测概率
。
- 计算预测误差
error = p - y
(其中是真实标签)。
- 累加梯度:
grad[0] += error
,(对于
)。
- 计算当前样本的预测概率
- 遍历结束后,使用累加的梯度更新权重:
,其中学习率
。
- 在每次迭代开始时,初始化一个梯度向量
- 逻辑回归模型的预测公式为
-
预测
- 读取并清洗测试数据。
- 对每一个清洗后的测试样本,使用最终训练得到的权重
计算
值和概率
。
- 根据阈值
作出判断:如果
,输出
;否则输出
。
代码
#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <cmath>
#include <numeric>
#include <algorithm>
#include <iomanip>
using namespace std;
// 数据样本结构体
struct Sample {
string id;
vector<double> features;
int status;
bool valid_label;
vector<bool> is_nan;
};
// 解析CSV行
Sample parse_line(const string& line, bool is_training) {
Sample s;
s.valid_label = true;
stringstream ss(line);
string item;
getline(ss, s.id, ',');
s.features.resize(5);
s.is_nan.resize(5, false);
for (int i = 0; i < 5; ++i) {
getline(ss, item, ',');
if (item == "NaN") {
s.is_nan[i] = true;
s.features[i] = 0.0;
} else {
try {
s.features[i] = stod(item);
} catch (...) {
s.is_nan[i] = true;
s.features[i] = 0.0;
}
}
}
if (is_training) {
if (getline(ss, item, ',')) {
try {
s.status = stoi(item);
if (s.status != 0 && s.status != 1) {
s.valid_label = false;
}
} catch (...) {
s.valid_label = false;
}
} else {
s.valid_label = false;
}
}
return s;
}
// 检查是否为异常值
bool is_outlier(double val, int feature_idx) {
if (feature_idx <= 1) return val < 0; // writes, reads
if (feature_idx <= 3) return val < 0 || val > 1000; // avg_write_ms, avg_read_ms
if (feature_idx == 4) return val < 0 || val > 20; // years
return false;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
string line;
getline(cin, line);
vector<Sample> training_samples;
for (int i = 0; i < n; ++i) {
getline(cin, line);
Sample s = parse_line(line, true);
if (s.valid_label) {
training_samples.push_back(s);
}
}
vector<double> means(5, 0.0);
vector<double> medians(5, 0.0);
for (int i = 0; i < 5; ++i) {
vector<double> valid_values;
for (const auto& s : training_samples) {
if (!s.is_nan[i] && !is_outlier(s.features[i], i)) {
valid_values.push_back(s.features[i]);
}
}
if (!valid_values.empty()) {
double sum = accumulate(valid_values.begin(), valid_values.end(), 0.0);
means[i] = sum / valid_values.size();
sort(valid_values.begin(), valid_values.end());
medians[i] = valid_values[valid_values.size() / 2];
if (valid_values.size() % 2 == 0) {
medians[i] = (valid_values[valid_values.size() / 2 - 1] + valid_values[valid_values.size() / 2]) / 2.0;
}
}
}
auto clean_data = [&](vector<Sample>& samples) {
for (auto& s : samples) {
for (int i = 0; i < 5; ++i) {
if (s.is_nan[i]) {
s.features[i] = means[i];
} else if (is_outlier(s.features[i], i)) {
s.features[i] = medians[i];
}
}
}
};
clean_data(training_samples);
vector<double> w(6, 0.0);
double learning_rate = 0.01;
int iterations = 100;
int num_train_samples = training_samples.size();
for (int iter = 0; iter < iterations; ++iter) {
vector<double> gradients(6, 0.0);
for (const auto& s : training_samples) {
double z = w[0];
for (int i = 0; i < 5; ++i) {
z += w[i + 1] * s.features[i];
}
double p = 1.0 / (1.0 + exp(-z));
double error = p - s.status;
gradients[0] += error;
for (int i = 0; i < 5; ++i) {
gradients[i + 1] += error * s.features[i];
}
}
for (int i = 0; i < 6; ++i) {
if (num_train_samples > 0) {
w[i] -= learning_rate * gradients[i] / num_train_samples;
}
}
}
int m;
cin >> m;
getline(cin, line);
vector<Sample> test_samples;
for (int i = 0; i < m; ++i) {
getline(cin, line);
test_samples.push_back(parse_line(line, false));
}
clean_data(test_samples);
for (const auto& s : test_samples) {
double z = w[0];
for (int i = 0; i < 5; ++i) {
z += w[i + 1] * s.features[i];
}
double p = 1.0 / (1.0 + exp(-z));
cout << (p >= 0.5 ? 1 : 0) << "\n";
}
return 0;
}
import java.util.*;
import java.io.*;
public class Main {
static class Sample {
String id;
double[] features = new double[5];
int status;
boolean validLabel = true;
boolean[] isNan = new boolean[5];
}
static boolean isOutlier(double val, int featureIdx) {
if (featureIdx <= 1) return val < 0; // writes, reads
if (featureIdx <= 3) return val < 0 || val > 1000; // avg_write_ms, avg_read_ms
if (featureIdx == 4) return val < 0 || val > 20; // years
return false;
}
static Sample parseLine(String line, boolean isTraining) {
Sample s = new Sample();
String[] parts = line.split(",");
s.id = parts[0];
for (int i = 0; i < 5; i++) {
if (parts[i + 1].equals("NaN")) {
s.isNan[i] = true;
s.features[i] = 0.0;
} else {
try {
s.features[i] = Double.parseDouble(parts[i + 1]);
} catch (NumberFormatException e) {
s.isNan[i] = true;
s.features[i] = 0.0;
}
}
}
if (isTraining) {
if (parts.length < 7) {
s.validLabel = false;
} else {
try {
s.status = Integer.parseInt(parts[6]);
if (s.status != 0 && s.status != 1) {
s.validLabel = false;
}
} catch (NumberFormatException e) {
s.validLabel = false;
}
}
}
return s;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = Integer.parseInt(sc.nextLine());
List<Sample> trainingSamples = new ArrayList<>();
for (int i = 0; i < n; i++) {
String line = sc.nextLine();
Sample s = parseLine(line, true);
if (s.validLabel) {
trainingSamples.add(s);
}
}
double[] means = new double[5];
double[] medians = new double[5];
for (int i = 0; i < 5; i++) {
List<Double> validValues = new ArrayList<>();
for (Sample s : trainingSamples) {
if (!s.isNan[i] && !isOutlier(s.features[i], i)) {
validValues.add(s.features[i]);
}
}
if (!validValues.isEmpty()) {
double sum = 0;
for (double val : validValues) {
sum += val;
}
means[i] = sum / validValues.size();
Collections.sort(validValues);
if (validValues.size() % 2 == 1) {
medians[i] = validValues.get(validValues.size() / 2);
} else {
medians[i] = (validValues.get(validValues.size() / 2 - 1) + validValues.get(validValues.size() / 2)) / 2.0;
}
}
}
// 清洗数据
for (Sample s : trainingSamples) {
for (int i = 0; i < 5; i++) {
if (s.isNan[i]) {
s.features[i] = means[i];
} else if (isOutlier(s.features[i], i)) {
s.features[i] = medians[i];
}
}
}
// 训练
double[] w = new double[6];
double learningRate = 0.01;
int iterations = 100;
int numTrainSamples = trainingSamples.size();
for (int iter = 0; iter < iterations; iter++) {
double[] gradients = new double[6];
for (Sample s : trainingSamples) {
double z = w[0];
for (int i = 0; i < 5; i++) {
z += w[i + 1] * s.features[i];
}
double p = 1.0 / (1.0 + Math.exp(-z));
double error = p - s.status;
gradients[0] += error;
for (int i = 0; i < 5; i++) {
gradients[i + 1] += error * s.features[i];
}
}
if (numTrainSamples > 0) {
for (int i = 0; i < 6; i++) {
w[i] -= learningRate * gradients[i] / numTrainSamples;
}
}
}
int m = Integer.parseInt(sc.nextLine());
List<Sample> testSamples = new ArrayList<>();
for (int i = 0; i < m; i++) {
testSamples.add(parseLine(sc.nextLine(), false));
}
for (Sample s : testSamples) {
for (int i = 0; i < 5; i++) {
if (s.isNan[i]) {
s.features[i] = means[i];
} else if (isOutlier(s.features[i], i)) {
s.features[i] = medians[i];
}
}
}
// 预测
for (Sample s : testSamples) {
double z = w[0];
for (int i = 0; i < 5; i++) {
z += w[i + 1] * s.features[i];
}
double p = 1.0 / (1.0 + Math.exp(-z));
System.out.println(p >= 0.5 ? 1 : 0);
}
}
}
import math
import statistics
# 检查是否为异常值
def is_outlier(val, feature_idx):
if feature_idx <= 1: # writes, reads
return val < 0
if feature_idx <= 3: # avg_write_ms, avg_read_ms
return val < 0 or val > 1000
if feature_idx == 4: # years
return val < 0 or val > 20
return False
# 解析行数据
def parse_line(line, is_training):
parts = line.strip().split(',')
device_id = parts[0]
features = []
is_nan = []
for i in range(1, 6):
if parts[i] == 'NaN':
features.append(0.0) # 临时值
is_nan.append(True)
else:
features.append(float(parts[i]))
is_nan.append(False)
status = -1
valid_label = True
if is_training:
if len(parts) < 7:
valid_label = False
else:
try:
status = int(parts[6])
if status not in [0, 1]:
valid_label = False
except ValueError:
valid_label = False
return device_id, features, status, valid_label, is_nan
def solve():
n = int(input())
training_samples = []
for _ in range(n):
line = input()
device_id, features, status, valid_label, is_nan = parse_line(line, True)
if valid_label:
training_samples.append({'id': device_id, 'features': features, 'status': status, 'is_nan': is_nan})
means = [0.0] * 5
medians = [0.0] * 5
for i in range(5):
valid_values = []
for sample in training_samples:
if not sample['is_nan'][i] and not is_outlier(sample['features'][i], i):
valid_values.append(sample['features'][i])
if valid_values:
means[i] = statistics.mean(valid_values)
medians[i] = statistics.median(valid_values)
def clean_data(samples):
for sample in samples:
for i in range(5):
if sample['is_nan'][i]:
sample['features'][i] = means[i]
elif is_outlier(sample['features'][i], i):
sample['features'][i] = medians[i]
clean_data(training_samples)
w = [0.0] * 6 # w0...w5
learning_rate = 0.01
iterations = 100
num_train_samples = len(training_samples)
for _ in range(iterations):
gradients = [0.0] * 6
if num_train_samples == 0:
break
for sample in training_samples:
z = w[0] + sum(w[i+1] * sample['features'][i] for i in range(5))
try:
p = 1.0 / (1.0 + math.exp(-z))
except OverflowError:
p = 0.0 if z < 0 else 1.0
error = p - sample['status']
gradients[0] += error
for i in range(5):
gradients[i+1] += error * sample['features'][i]
for i in range(6):
w[i] -= learning_rate * gradients[i] / num_train_samples
m = int(input())
test_samples = []
for _ in range(m):
line = input()
device_id, features, _, _, is_nan = parse_line(line, False)
test_samples.append({'id': device_id, 'features': features, 'is_nan': is_nan})
clean_data(test_samples)
for sample in test_samples:
z = w[0] + sum(w[i+1] * sample['features'][i] for i in range(5))
try:
p = 1.0 / (1.0 + math.exp(-z))
except OverflowError:
p = 0.0 if z < 0 else 1.0
print(1 if p >= 0.5 else 0)
solve()
算法及复杂度
- 算法:机器学习、逻辑回归、批量梯度下降
- 时间复杂度:
,其中
是迭代次数(
),
是有效训练样本数,
是特征数量(
)。第一项是模型训练的时间,第二项是为计算中位数而对每个特征的有效值进行排序的时间。
- 空间复杂度:
,需要存储所有训练和测试样本的数据,其中
是测试样本数。