题目链接

设备故障预测程序

题目描述

本题要求实现一个设备故障预测程序。你需要基于给定的训练数据集(包含设备ID、写入次数、读取次数、平均读写延迟、使用年限和设备状态),通过数据清洗和处理,学习一个逻辑回归模型。然后,使用这个模型对新的待预测设备数据进行故障预测(判断为正常或故障)。

核心任务包括:

  1. 数据清洗
    • 缺失值填充:对数值字段中的 NaN 字符串,使用训练集中该字段有效数值的均值进行填充。
    • 异常值处理:对超出指定范围(如写入次数小于0)的数值,使用训练集中该字段有效数值的中位数进行替换。
    • 标签缺失处理:丢弃状态字段缺失或无效的训练样本。
  2. 模型训练
    • 使用二分类逻辑回归模型,包含偏置项
    • 采用批量梯度下降法(Batch GD)进行训练,学习率为 ,迭代 次,初始权重全部为
  3. 预测
    • 根据学习到的权重,计算待预测样本发生故障的概率
    • 若概率大于等于 ,则判定为故障(输出 1),否则判定为正常(输出 0)。

解题思路

这是一个典型的机器学习流程模拟题,需要严格按照题目要求实现数据处理、模型训练和预测三个步骤。

  1. 数据读取与初步筛选

    • 首先,读取训练样本数量 和所有训练数据。
    • 对每一行训练数据进行解析。由于数据可能包含 NaN 字符串,我们需要将特征暂时存为字符串或能够表示缺失值的形式。
    • 筛选掉标签(status)字段缺失或无法解析为 的样本。这些样本不参与后续任何计算。
  2. 计算清洗参数

    • 遍历筛选后的有效训练样本,为 个数值特征(writes, reads, avg_write_ms, avg_read_ms, years)分别建立一个列表,用于存放“有效数值”。
    • 一个数值是“有效的”,当且仅当它不是 NaN 且不在题目定义的异常值范围内。
    • 对每个特征的有效数值列表,计算其均值和中位数。
      • 均值:用于后续填充 NaN
      • 中位数:用于后续修正异常值。
    • 如果某个特征在训练集中没有任何有效数值,则其均值和中位数都按 处理。
  3. 数据清洗

    • 使用上一步计算出的均值和中位数,对全部训练数据和测试数据进行清洗。
    • 遍历所有样本(包括训练和测试)的每一个特征:
      • 如果值为 NaN,则替换为对应特征的均值。
      • 如果值为异常值,则替换为对应特征的中位数。
    • 经过此步骤后,所有数据都变为干净的数值型数据,可以用于模型训练。
  4. 模型训练(批量梯度下降)

    • 逻辑回归模型的预测公式为 ,其中
    • 初始化权重向量 (包含 个权重)为全
    • 执行 次迭代:
      • 在每次迭代开始时,初始化一个梯度向量 grad 为全
      • 遍历所有 个清洗后的训练样本:
        • 计算当前样本的预测概率
        • 计算预测误差 error = p - y(其中 是真实标签)。
        • 累加梯度:grad[0] += error(对于 )。
      • 遍历结束后,使用累加的梯度更新权重:,其中学习率
  5. 预测

    • 读取并清洗测试数据。
    • 对每一个清洗后的测试样本,使用最终训练得到的权重 计算 值和概率
    • 根据阈值 作出判断:如果 ,输出 ;否则输出

代码

#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <cmath>
#include <numeric>
#include <algorithm>
#include <iomanip>

using namespace std;

// 数据样本结构体
struct Sample {
    string id;
    vector<double> features;
    int status;
    bool valid_label;
    vector<bool> is_nan;
};

// 解析CSV行
Sample parse_line(const string& line, bool is_training) {
    Sample s;
    s.valid_label = true;
    stringstream ss(line);
    string item;
    
    getline(ss, s.id, ',');
    
    s.features.resize(5);
    s.is_nan.resize(5, false);

    for (int i = 0; i < 5; ++i) {
        getline(ss, item, ',');
        if (item == "NaN") {
            s.is_nan[i] = true;
            s.features[i] = 0.0;
        } else {
            try {
                s.features[i] = stod(item);
            } catch (...) {
                s.is_nan[i] = true;
                s.features[i] = 0.0;
            }
        }
    }

    if (is_training) {
        if (getline(ss, item, ',')) {
            try {
                s.status = stoi(item);
                if (s.status != 0 && s.status != 1) {
                    s.valid_label = false;
                }
            } catch (...) {
                s.valid_label = false;
            }
        } else {
            s.valid_label = false;
        }
    }
    return s;
}

// 检查是否为异常值
bool is_outlier(double val, int feature_idx) {
    if (feature_idx <= 1) return val < 0; // writes, reads
    if (feature_idx <= 3) return val < 0 || val > 1000; // avg_write_ms, avg_read_ms
    if (feature_idx == 4) return val < 0 || val > 20; // years
    return false;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;
    string line;
    getline(cin, line); 

    vector<Sample> training_samples;
    for (int i = 0; i < n; ++i) {
        getline(cin, line);
        Sample s = parse_line(line, true);
        if (s.valid_label) {
            training_samples.push_back(s);
        }
    }

    vector<double> means(5, 0.0);
    vector<double> medians(5, 0.0);

    for (int i = 0; i < 5; ++i) {
        vector<double> valid_values;
        for (const auto& s : training_samples) {
            if (!s.is_nan[i] && !is_outlier(s.features[i], i)) {
                valid_values.push_back(s.features[i]);
            }
        }

        if (!valid_values.empty()) {
            double sum = accumulate(valid_values.begin(), valid_values.end(), 0.0);
            means[i] = sum / valid_values.size();
            sort(valid_values.begin(), valid_values.end());
            medians[i] = valid_values[valid_values.size() / 2];
            if (valid_values.size() % 2 == 0) {
                 medians[i] = (valid_values[valid_values.size() / 2 - 1] + valid_values[valid_values.size() / 2]) / 2.0;
            }
        }
    }

    auto clean_data = [&](vector<Sample>& samples) {
        for (auto& s : samples) {
            for (int i = 0; i < 5; ++i) {
                if (s.is_nan[i]) {
                    s.features[i] = means[i];
                } else if (is_outlier(s.features[i], i)) {
                    s.features[i] = medians[i];
                }
            }
        }
    };
    
    clean_data(training_samples);

    vector<double> w(6, 0.0);
    double learning_rate = 0.01;
    int iterations = 100;
    int num_train_samples = training_samples.size();

    for (int iter = 0; iter < iterations; ++iter) {
        vector<double> gradients(6, 0.0);
        for (const auto& s : training_samples) {
            double z = w[0];
            for (int i = 0; i < 5; ++i) {
                z += w[i + 1] * s.features[i];
            }
            double p = 1.0 / (1.0 + exp(-z));
            double error = p - s.status;
            gradients[0] += error;
            for (int i = 0; i < 5; ++i) {
                gradients[i + 1] += error * s.features[i];
            }
        }
        for (int i = 0; i < 6; ++i) {
            if (num_train_samples > 0) {
                w[i] -= learning_rate * gradients[i] / num_train_samples;
            }
        }
    }

    int m;
    cin >> m;
    getline(cin, line);
    vector<Sample> test_samples;
    for (int i = 0; i < m; ++i) {
        getline(cin, line);
        test_samples.push_back(parse_line(line, false));
    }
    
    clean_data(test_samples);

    for (const auto& s : test_samples) {
        double z = w[0];
        for (int i = 0; i < 5; ++i) {
            z += w[i + 1] * s.features[i];
        }
        double p = 1.0 / (1.0 + exp(-z));
        cout << (p >= 0.5 ? 1 : 0) << "\n";
    }

    return 0;
}
import java.util.*;
import java.io.*;

public class Main {

    static class Sample {
        String id;
        double[] features = new double[5];
        int status;
        boolean validLabel = true;
        boolean[] isNan = new boolean[5];
    }

    static boolean isOutlier(double val, int featureIdx) {
        if (featureIdx <= 1) return val < 0; // writes, reads
        if (featureIdx <= 3) return val < 0 || val > 1000; // avg_write_ms, avg_read_ms
        if (featureIdx == 4) return val < 0 || val > 20; // years
        return false;
    }
    
    static Sample parseLine(String line, boolean isTraining) {
        Sample s = new Sample();
        String[] parts = line.split(",");
        
        s.id = parts[0];
        
        for (int i = 0; i < 5; i++) {
            if (parts[i + 1].equals("NaN")) {
                s.isNan[i] = true;
                s.features[i] = 0.0;
            } else {
                try {
                    s.features[i] = Double.parseDouble(parts[i + 1]);
                } catch (NumberFormatException e) {
                    s.isNan[i] = true;
                    s.features[i] = 0.0;
                }
            }
        }
        
        if (isTraining) {
            if (parts.length < 7) {
                s.validLabel = false;
            } else {
                try {
                    s.status = Integer.parseInt(parts[6]);
                    if (s.status != 0 && s.status != 1) {
                        s.validLabel = false;
                    }
                } catch (NumberFormatException e) {
                    s.validLabel = false;
                }
            }
        }
        return s;
    }


    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = Integer.parseInt(sc.nextLine());
        
        List<Sample> trainingSamples = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            String line = sc.nextLine();
            Sample s = parseLine(line, true);
            if (s.validLabel) {
                trainingSamples.add(s);
            }
        }
        
        double[] means = new double[5];
        double[] medians = new double[5];

        for (int i = 0; i < 5; i++) {
            List<Double> validValues = new ArrayList<>();
            for (Sample s : trainingSamples) {
                if (!s.isNan[i] && !isOutlier(s.features[i], i)) {
                    validValues.add(s.features[i]);
                }
            }

            if (!validValues.isEmpty()) {
                double sum = 0;
                for (double val : validValues) {
                    sum += val;
                }
                means[i] = sum / validValues.size();
                
                Collections.sort(validValues);
                if (validValues.size() % 2 == 1) {
                    medians[i] = validValues.get(validValues.size() / 2);
                } else {
                    medians[i] = (validValues.get(validValues.size() / 2 - 1) + validValues.get(validValues.size() / 2)) / 2.0;
                }
            }
        }
        
        // 清洗数据
        for (Sample s : trainingSamples) {
            for (int i = 0; i < 5; i++) {
                if (s.isNan[i]) {
                    s.features[i] = means[i];
                } else if (isOutlier(s.features[i], i)) {
                    s.features[i] = medians[i];
                }
            }
        }
        
        // 训练
        double[] w = new double[6];
        double learningRate = 0.01;
        int iterations = 100;
        int numTrainSamples = trainingSamples.size();

        for (int iter = 0; iter < iterations; iter++) {
            double[] gradients = new double[6];
            for (Sample s : trainingSamples) {
                double z = w[0];
                for (int i = 0; i < 5; i++) {
                    z += w[i + 1] * s.features[i];
                }
                double p = 1.0 / (1.0 + Math.exp(-z));
                double error = p - s.status;
                gradients[0] += error;
                for (int i = 0; i < 5; i++) {
                    gradients[i + 1] += error * s.features[i];
                }
            }
            if (numTrainSamples > 0) {
                for (int i = 0; i < 6; i++) {
                    w[i] -= learningRate * gradients[i] / numTrainSamples;
                }
            }
        }

        int m = Integer.parseInt(sc.nextLine());
        List<Sample> testSamples = new ArrayList<>();
        for (int i = 0; i < m; i++) {
            testSamples.add(parseLine(sc.nextLine(), false));
        }

        for (Sample s : testSamples) {
            for (int i = 0; i < 5; i++) {
                if (s.isNan[i]) {
                    s.features[i] = means[i];
                } else if (isOutlier(s.features[i], i)) {
                    s.features[i] = medians[i];
                }
            }
        }
        
        // 预测
        for (Sample s : testSamples) {
            double z = w[0];
            for (int i = 0; i < 5; i++) {
                z += w[i + 1] * s.features[i];
            }
            double p = 1.0 / (1.0 + Math.exp(-z));
            System.out.println(p >= 0.5 ? 1 : 0);
        }
    }
}
import math
import statistics

# 检查是否为异常值
def is_outlier(val, feature_idx):
    if feature_idx <= 1:  # writes, reads
        return val < 0
    if feature_idx <= 3:  # avg_write_ms, avg_read_ms
        return val < 0 or val > 1000
    if feature_idx == 4:  # years
        return val < 0 or val > 20
    return False

# 解析行数据
def parse_line(line, is_training):
    parts = line.strip().split(',')
    device_id = parts[0]
    features = []
    is_nan = []
    for i in range(1, 6):
        if parts[i] == 'NaN':
            features.append(0.0) # 临时值
            is_nan.append(True)
        else:
            features.append(float(parts[i]))
            is_nan.append(False)
            
    status = -1
    valid_label = True
    if is_training:
        if len(parts) < 7:
            valid_label = False
        else:
            try:
                status = int(parts[6])
                if status not in [0, 1]:
                    valid_label = False
            except ValueError:
                valid_label = False
    return device_id, features, status, valid_label, is_nan

def solve():
    n = int(input())
    training_samples = []
    for _ in range(n):
        line = input()
        device_id, features, status, valid_label, is_nan = parse_line(line, True)
        if valid_label:
            training_samples.append({'id': device_id, 'features': features, 'status': status, 'is_nan': is_nan})

    means = [0.0] * 5
    medians = [0.0] * 5

    for i in range(5):
        valid_values = []
        for sample in training_samples:
            if not sample['is_nan'][i] and not is_outlier(sample['features'][i], i):
                valid_values.append(sample['features'][i])
        
        if valid_values:
            means[i] = statistics.mean(valid_values)
            medians[i] = statistics.median(valid_values)

    def clean_data(samples):
        for sample in samples:
            for i in range(5):
                if sample['is_nan'][i]:
                    sample['features'][i] = means[i]
                elif is_outlier(sample['features'][i], i):
                    sample['features'][i] = medians[i]

    clean_data(training_samples)

    w = [0.0] * 6  # w0...w5
    learning_rate = 0.01
    iterations = 100
    num_train_samples = len(training_samples)

    for _ in range(iterations):
        gradients = [0.0] * 6
        if num_train_samples == 0:
            break
            
        for sample in training_samples:
            z = w[0] + sum(w[i+1] * sample['features'][i] for i in range(5))
            try:
                p = 1.0 / (1.0 + math.exp(-z))
            except OverflowError:
                p = 0.0 if z < 0 else 1.0
                
            error = p - sample['status']
            gradients[0] += error
            for i in range(5):
                gradients[i+1] += error * sample['features'][i]
        
        for i in range(6):
            w[i] -= learning_rate * gradients[i] / num_train_samples

    m = int(input())
    test_samples = []
    for _ in range(m):
        line = input()
        device_id, features, _, _, is_nan = parse_line(line, False)
        test_samples.append({'id': device_id, 'features': features, 'is_nan': is_nan})
        
    clean_data(test_samples)

    for sample in test_samples:
        z = w[0] + sum(w[i+1] * sample['features'][i] for i in range(5))
        try:
            p = 1.0 / (1.0 + math.exp(-z))
        except OverflowError:
            p = 0.0 if z < 0 else 1.0
        
        print(1 if p >= 0.5 else 0)

solve()

算法及复杂度

  • 算法:机器学习、逻辑回归、批量梯度下降
  • 时间复杂度,其中 是迭代次数(), 是有效训练样本数, 是特征数量()。第一项是模型训练的时间,第二项是为计算中位数而对每个特征的有效值进行排序的时间。
  • 空间复杂度,需要存储所有训练和测试样本的数据,其中 是测试样本数。