ID3决策树训练

题意

给定 $n$ 个训练样本,每个样本有 $m$ 个二值特征和一个二值标签(0=正常,1=退化)。要求用 ID3 算法构建决策树,然后对 $q$ 个查询样本进行预测。

思路

什么是 ID3 决策树?

ID3 的核心思想很简单:每次选一个"最有区分力"的特征把数据一分为二,递归建树,直到不需要再分为止。

那怎么衡量"区分力"?用信息增益

信息增益怎么算?

先回顾信息熵。对于一组标签,信息熵定义为:

$$
H(S) = -\sum_{i} p_i \log_2 p_i
$$

其中 $p_i$ 是各类别的占比。熵越大,数据越"混乱";熵为 0,说明所有样本同类。

选特征 $f$ 把数据分成"特征值为 0"和"特征值为 1"两组后,加权平均熵为:

$$
H(S \mid f) = \frac{|S_0|}{|S|} H(S_0) + \frac{|S_1|}{|S|} H(S_1)
$$

信息增益就是分裂前后熵的差:$\mathrm{Gain}(f) = H(S) - H(S \mid f)$。增益越大,说明这个特征越能把不同类别分开。

建树过程

递归建树,每一步:

  1. 终止条件:所有标签相同,直接返回该标签;没有可用特征或没有正增益,返回多数类(相同时返回 0);子集为空,返回父节点的多数类。
  2. 选特征:遍历所有可用特征,算信息增益,选最大的。增益相同时选下标更小的。
  3. 分裂:按选出的特征把数据分成两组,对两组分别递归。

预测

从根节点出发,看当前节点对应的特征值是 0 还是 1,走向左子树或右子树,直到叶节点,返回叶节点的标签。

复杂度

建树时每层最多遍历 $m$ 个特征,树深度最多 $m$,每层扫描所有样本,时间 $O(n m^2)$。预测每个查询 $O(m)$,总查询 $O(qm)$。

代码

import math
import sys
from collections import Counter

def entropy(labels):
    """Shannon entropy (in bits) of a sequence of class labels.

    Returns 0.0 for an empty sequence or when all labels are equal.
    """
    total = len(labels)
    if total == 0:
        return 0.0
    counts = {}
    for lab in labels:
        counts[lab] = counts.get(lab, 0) + 1
    # H = -sum(p * log2(p)); every count is > 0 by construction.
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

def majority_label(labels):
    """Most frequent binary label; ties (and empty input) resolve to 0."""
    ones = sum(1 for lab in labels if lab == 1)
    zeros = len(labels) - ones
    return 1 if ones > zeros else 0

def build_tree(data, features_available):
    """Recursively build an ID3 decision tree.

    data: list of (feature_list, label) pairs with binary features/labels.
    features_available: set of feature indices still usable on this path.

    Returns an int leaf label, a tuple (feature, child_for_0, child_for_1),
    or None when data is empty (the caller substitutes the parent majority).
    """
    if not data:
        return None

    labels = [lbl for _, lbl in data]

    # All labels agree: pure leaf.
    if all(lbl == labels[0] for lbl in labels):
        return labels[0]

    # No features left to split on: fall back to the majority class.
    if not features_available:
        return majority_label(labels)

    total = len(data)
    parent_entropy = entropy(labels)
    best_gain, best_feat = -1, -1

    # Sorted iteration + strict '>' means ties pick the smallest feature index.
    for feat in sorted(features_available):
        zero_side = [lbl for fv, lbl in data if fv[feat] == 0]
        one_side = [lbl for fv, lbl in data if fv[feat] == 1]
        weighted = (len(zero_side) * entropy(zero_side)
                    + len(one_side) * entropy(one_side)) / total
        gain = parent_entropy - weighted
        if gain > best_gain:
            best_gain, best_feat = gain, feat

    # No positive gain: splitting cannot help, return majority class.
    if best_gain <= 1e-12:
        return majority_label(labels)

    remaining = features_available - {best_feat}
    zero_data = [sample for sample in data if sample[0][best_feat] == 0]
    one_data = [sample for sample in data if sample[0][best_feat] == 1]

    zero_child = build_tree(zero_data, remaining)
    one_child = build_tree(one_data, remaining)

    # An empty branch inherits this node's majority label.
    fallback = majority_label(labels)
    return (best_feat,
            fallback if zero_child is None else zero_child,
            fallback if one_child is None else one_child)

def predict(tree, features):
    """Walk from the root to a leaf and return its label.

    Leaves are ints; internal nodes are (feature, child_for_0, child_for_1).
    """
    node = tree
    while not isinstance(node, int):
        feat, zero_branch, one_branch = node
        node = zero_branch if features[feat] == 0 else one_branch
    return node

def main():
    """Read input, train the ID3 tree, and answer all queries.

    Input: n m, then n rows of m feature bits plus one label bit,
    then q, then q rows of m feature bits. Outputs one prediction per line.
    """
    tokens = sys.stdin.read().split()
    idx = 0
    n = int(tokens[idx]); idx += 1
    m = int(tokens[idx]); idx += 1

    data = []
    for _ in range(n):
        feats = [int(tokens[idx + j]) for j in range(m)]
        idx += m
        label = int(tokens[idx]); idx += 1
        data.append((feats, label))

    q = int(tokens[idx]); idx += 1
    tree = build_tree(data, set(range(m)))

    results = []
    for _ in range(q):
        feats = [int(tokens[idx + j]) for j in range(m)]
        idx += m
        results.append(str(predict(tree, feats)))

    print('\n'.join(results))

# Guard so importing this module does not trigger I/O; running as a
# script behaves exactly as before.
if __name__ == "__main__":
    main()
#include <bits/stdc++.h>
using namespace std;

// Flat-array decision-tree node; children are indices into the global `tree`.
struct Node {
    int type; // 0=leaf, 1=internal
    int label, feat, left, right; // label used by leaves; feat/left/right by internal nodes (-1 when unused)
};

vector<Node> tree; // all nodes; buildTree appends here and returns the root index
int n, m;          // n = number of training samples, m = number of binary features

// Shannon entropy (in bits) of a binary label vector; 0.0 for an empty input.
double entropy(const vector<int>& labels) {
    const int total = (int)labels.size();
    if (total == 0) return 0.0;
    int ones = 0;
    for (int v : labels) ones += v;
    // H = -sum over both classes of p*log2(p), skipping empty classes.
    double result = 0.0;
    for (int cnt : {total - ones, ones}) {
        if (cnt > 0) {
            const double p = (double)cnt / total;
            result -= p * log2(p);
        }
    }
    return result;
}

// Most frequent binary label; ties (and empty input) resolve to 0.
int majorityLabel(const vector<int>& labels) {
    int ones = 0;
    for (int v : labels) ones += v;
    const int zeros = (int)labels.size() - ones;
    if (ones > zeros) return 1;
    return 0;
}

// Recursively builds an ID3 subtree over the sample rows listed in `indices`.
// feats rows hold m feature bits followed by the label in column m.
// `available` is mutated during recursion but restored before returning.
// Returns the index of the subtree root inside the global `tree` vector.
int buildTree(const vector<vector<int>>& feats, const vector<int>& indices, set<int>& available) {
    // Reserve this node's slot first so child indices stay stable even if
    // `tree` reallocates during recursion.
    const int self = (int)tree.size();
    tree.push_back({});

    if (indices.empty()) {
        tree[self] = {0, 0, -1, -1, -1};
        return self;
    }

    vector<int> labels;
    labels.reserve(indices.size());
    for (int i : indices) labels.push_back(feats[i][m]);

    // Pure subset: emit a leaf with the common label.
    bool uniform = true;
    for (int l : labels)
        if (l != labels[0]) { uniform = false; break; }
    if (uniform) {
        tree[self] = {0, labels[0], -1, -1, -1};
        return self;
    }

    // No features left: majority-class leaf.
    if (available.empty()) {
        tree[self] = {0, majorityLabel(labels), -1, -1, -1};
        return self;
    }

    const double parentEnt = entropy(labels);
    const int total = (int)indices.size();
    double topGain = -1;
    int topFeat = -1;

    // std::set iterates in ascending order and '>' is strict, so ties pick
    // the smallest feature index.
    for (int f : available) {
        vector<int> lo, hi;
        for (int i : indices)
            (feats[i][f] == 0 ? lo : hi).push_back(feats[i][m]);
        const double splitEnt = ((double)lo.size() / total) * entropy(lo)
                              + ((double)hi.size() / total) * entropy(hi);
        const double gain = parentEnt - splitEnt;
        if (gain > topGain) { topGain = gain; topFeat = f; }
    }

    // No positive gain: splitting cannot help, emit a majority leaf.
    if (topGain <= 1e-12) {
        tree[self] = {0, majorityLabel(labels), -1, -1, -1};
        return self;
    }

    const int fallback = majorityLabel(labels);
    vector<int> loIdx, hiIdx;
    for (int i : indices)
        (feats[i][topFeat] == 0 ? loIdx : hiIdx).push_back(i);

    available.erase(topFeat);
    int loChild, hiChild;
    if (loIdx.empty()) {
        // Empty branch becomes a leaf carrying this node's majority label.
        loChild = (int)tree.size();
        tree.push_back({0, fallback, -1, -1, -1});
    } else {
        loChild = buildTree(feats, loIdx, available);
    }
    if (hiIdx.empty()) {
        hiChild = (int)tree.size();
        tree.push_back({0, fallback, -1, -1, -1});
    } else {
        hiChild = buildTree(feats, hiIdx, available);
    }
    available.insert(topFeat);

    tree[self] = {1, -1, topFeat, loChild, hiChild};
    return self;
}

// Descends the global tree from `node` to a leaf: feature value 0 goes left,
// 1 goes right. Returns the leaf's label.
int predict(int node, const vector<int>& feat) {
    while (tree[node].type != 0)
        node = (feat[tree[node].feat] == 0) ? tree[node].left : tree[node].right;
    return tree[node].label;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    // Each training row: m feature bits followed by the label in column m.
    vector<vector<int>> samples(n, vector<int>(m + 1));
    for (auto& row : samples)
        for (auto& cell : row) cin >> cell;

    vector<int> allRows(n);
    iota(allRows.begin(), allRows.end(), 0);
    set<int> features;
    for (int f = 0; f < m; ++f) features.insert(f);
    tree.reserve(1000);
    const int root = buildTree(samples, allRows, features);

    int q;
    cin >> q;
    vector<int> query(m); // reused across queries; fully overwritten each time
    while (q--) {
        for (int j = 0; j < m; ++j) cin >> query[j];
        cout << predict(root, query) << '\n';
    }
    return 0;
}