信用评分模型优化

[题目链接](https://www.nowcoder.com/practice/27a2c6a4127c4c7080d34fc017d337bc)

思路

给定一组申请者的特征数据和信用评分结果('G' 表示良好,'B' 表示不良),需要计算每个特征的信息增益比(Information Gain Ratio),输出信息增益比最高的特征索引。

信息增益比

信息增益比是 C4.5 决策树算法使用的特征选择准则,定义为:

$$

其中:

  • 是特征 对数据集 信息增益
  • 是特征 固有值(Intrinsic Value),也称分裂信息。

信息熵

数据集 的信息熵定义为:

$$

其中 是第 类样本在 中的比例, 是类别数。

信息增益

特征 将数据集 按取值划分为若干子集 ,则信息增益为:

$$

即:原始熵减去按特征划分后的条件熵。

固有值(分裂信息)

$$

固有值衡量特征 本身的"分散程度"。取值越多、分布越均匀,固有值越大。用信息增益除以固有值,可以校正信息增益对多值特征的偏好。

算法步骤

  1. 解析输入的二维列表,最后一列为标签,其余列为特征。
  2. 计算标签的整体信息熵
  3. 对每个特征,按其取值将样本分组,计算条件熵、信息增益、固有值、信息增益比。
  4. 输出信息增益比最大的特征索引。

样例演示

输入 5 条记录,3 个特征(年龄、年收入、信用卡余额),标签为 G/B。

标签分布:3 个 G,2 个 B,整体熵

由于三个特征的取值都各不相同(5 个样本 5 个不同值),每个子集只有 1 个样本,条件熵均为 0,信息增益均等于 。同时三个特征的固有值也相同(),因此信息增益比相同,输出第一个特征的索引 0。

复杂度分析

设样本数为 ,特征数为

  • 时间复杂度:,对每个特征遍历所有样本一次。
  • 空间复杂度:,存储输入数据。

代码

#include <iostream>
#include <string>
#include <vector>
#include <map>
#include <cmath>
using namespace std;

double entropy(const vector<string>& labels) {
    if (labels.empty()) return 0.0;
    map<string, int> counts;
    for (auto& l : labels) counts[l]++;
    double ent = 0.0;
    int n = labels.size();
    for (auto& [k, c] : counts) {
        double p = (double)c / n;
        if (p > 0) ent -= p * log2(p);
    }
    return ent;
}

int main() {
    string line;
    getline(cin, line);

    // 解析 Python 风格的二维列表
    vector<vector<string>> data;
    int depth = 0, start = -1;
    for (int i = 0; i < (int)line.size(); i++) {
        if (line[i] == '[') {
            depth++;
            if (depth == 2) start = i + 1;
        } else if (line[i] == ']') {
            depth--;
            if (depth == 1) {
                string sub = line.substr(start, i - start);
                vector<string> row;
                string token;
                for (int j = 0; j <= (int)sub.size(); j++) {
                    if (j == (int)sub.size() || sub[j] == ',') {
                        int a = 0, b = token.size();
                        while (a < b && (token[a] == ' ' || token[a] == '\'' || token[a] == '"')) a++;
                        while (b > a && (token[b-1] == ' ' || token[b-1] == '\'' || token[b-1] == '"')) b--;
                        row.push_back(token.substr(a, b - a));
                        token.clear();
                    } else {
                        token += sub[j];
                    }
                }
                data.push_back(row);
            }
        }
    }

    int n = data.size();
    int numFeatures = data[0].size() - 1;
    vector<string> labels(n);
    for (int i = 0; i < n; i++) labels[i] = data[i].back();
    double hY = entropy(labels);

    int bestIdx = 0;
    double bestRatio = -1.0;
    for (int f = 0; f < numFeatures; f++) {
        map<string, vector<string>> groups;
        for (int i = 0; i < n; i++)
            groups[data[i][f]].push_back(labels[i]);

        double hCond = 0.0;
        for (auto& [k, subset] : groups) {
            double p = (double)subset.size() / n;
            hCond += p * entropy(subset);
        }
        double gain = hY - hCond;

        double splitInfo = 0.0;
        for (auto& [k, subset] : groups) {
            double p = (double)subset.size() / n;
            if (p > 0) splitInfo -= p * log2(p);
        }

        double ratio = (splitInfo == 0) ? 0.0 : gain / splitInfo;
        if (ratio > bestRatio) {
            bestRatio = ratio;
            bestIdx = f;
        }
    }

    cout << bestIdx << endl;
    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        String line = sc.nextLine().trim();

        // 解析 Python 风格的二维列表
        List<List<String>> data = new ArrayList<>();
        line = line.substring(1, line.length() - 1);
        int depth = 0, start = -1;
        for (int i = 0; i < line.length(); i++) {
            char c = line.charAt(i);
            if (c == '[') {
                depth++;
                if (depth == 1) start = i + 1;
            } else if (c == ']') {
                depth--;
                if (depth == 0) {
                    String sub = line.substring(start, i).trim();
                    List<String> row = new ArrayList<>();
                    for (String p : sub.split(",")) {
                        row.add(p.trim().replace("'", "").replace("\"", ""));
                    }
                    data.add(row);
                }
            }
        }

        int n = data.size();
        int numFeatures = data.get(0).size() - 1;
        String[] labels = new String[n];
        for (int i = 0; i < n; i++)
            labels[i] = data.get(i).get(data.get(i).size() - 1);
        double hY = entropy(labels);

        int bestIdx = 0;
        double bestRatio = -1.0;
        for (int f = 0; f < numFeatures; f++) {
            Map<String, List<String>> groups = new LinkedHashMap<>();
            for (int i = 0; i < n; i++)
                groups.computeIfAbsent(data.get(i).get(f), k -> new ArrayList<>()).add(labels[i]);

            double hCond = 0.0;
            for (List<String> subset : groups.values()) {
                double p = (double) subset.size() / n;
                hCond += p * entropy(subset.toArray(new String[0]));
            }
            double gain = hY - hCond;

            double splitInfo = 0.0;
            for (List<String> subset : groups.values()) {
                double p = (double) subset.size() / n;
                if (p > 0) splitInfo -= p * (Math.log(p) / Math.log(2));
            }

            double ratio = (splitInfo == 0) ? 0.0 : gain / splitInfo;
            if (ratio > bestRatio) {
                bestRatio = ratio;
                bestIdx = f;
            }
        }
        System.out.println(bestIdx);
    }

    static double entropy(String[] labels) {
        if (labels.length == 0) return 0.0;
        Map<String, Integer> counts = new HashMap<>();
        for (String l : labels) counts.merge(l, 1, Integer::sum);
        double ent = 0.0;
        for (int c : counts.values()) {
            double p = (double) c / labels.length;
            if (p > 0) ent -= p * (Math.log(p) / Math.log(2));
        }
        return ent;
    }
}
import math
from collections import Counter

def entropy(labels):
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return -sum((c / n) * math.log2(c / n) for c in counts.values() if c > 0)

def info_gain_ratio(data, feat_idx):
    n = len(data)
    labels = [row[-1] for row in data]
    h_y = entropy(labels)

    # 按特征值分组
    groups = {}
    for row in data:
        groups.setdefault(row[feat_idx], []).append(row[-1])

    # 条件熵
    h_cond = sum(len(g) / n * entropy(g) for g in groups.values())
    gain = h_y - h_cond

    # 固有值(分裂信息)
    split_info = -sum((len(g) / n) * math.log2(len(g) / n)
                      for g in groups.values() if len(g) > 0)

    return gain / split_info if split_info > 0 else 0.0

data = eval(input())
num_features = len(data[0]) - 1
best_idx = max(range(num_features), key=lambda i: info_gain_ratio(data, i))
print(best_idx)