ID3决策树训练
题意
给定 $n$ 个训练样本,每个样本有 $m$ 个二值特征和一个二值标签(0=正常,1=退化)。要求用 ID3 算法构建决策树,然后对
$q$ 个查询样本进行预测。
思路
什么是 ID3 决策树?
ID3 的核心思想很简单:每次选一个"最有区分力"的特征把数据一分为二,递归建树,直到不需要再分为止。
那怎么衡量"区分力"?用信息增益。
信息增益怎么算?
先回顾信息熵。对于一组标签 $S$,信息熵定义为:
$$H(S) = -\sum_i p_i \log_2 p_i$$
其中 $p_i$ 是各类别的占比。熵越大,数据越"混乱";熵为 0,说明所有样本同类。
选特征 $f$ 把数据分成 "$f=0$" 和 "$f=1$" 两组 $S_0, S_1$ 后,加权平均熵为:
$$H_f(S) = \frac{|S_0|}{|S|}H(S_0) + \frac{|S_1|}{|S|}H(S_1)$$
信息增益就是分裂前后熵的差:$\mathrm{Gain}(f) = H(S) - H_f(S)$。增益越大,说明这个特征越能把不同类别分开。
建树过程
递归建树,每一步:
- 终止条件:所有标签相同,直接返回该标签;没有可用特征或没有正增益,返回多数类(相同时返回 0);子集为空,返回父节点的多数类。
- 选特征:遍历所有可用特征,算信息增益,选最大的。增益相同时选下标更小的。
- 分裂:按选出的特征把数据分成两组,对两组分别递归。
预测
从根节点出发,看当前节点对应的特征值是 0 还是 1,走向左子树或右子树,直到叶节点,返回叶节点的标签。
复杂度
建树时每层最多遍历 $m$ 个特征,树深度最多 $m$,每个候选特征要扫描当前全部样本,时间 $O(n m^2)$。预测每个查询 $O(m)$,总查询 $O(qm)$。
代码
import math
import sys
from collections import Counter
def entropy(labels):
    """Shannon entropy (base 2) of a label sequence; 0.0 for empty input."""
    total = len(labels)
    if total == 0:
        return 0.0
    result = 0.0
    # Counter only reports labels that occur, so every count is positive.
    for count in Counter(labels).values():
        ratio = count / total
        result -= ratio * math.log2(ratio)
    return result
def majority_label(labels):
    """Return the most frequent binary label, preferring 0 on ties (and for empty input)."""
    counts = Counter(labels)
    return 1 if counts.get(1, 0) > counts.get(0, 0) else 0
def build_tree(data, features_available):
    """Recursively build an ID3 decision tree.

    data: list of (feature_list, label) pairs with 0/1 entries.
    features_available: set of feature indices still usable on this path.
    Returns an int leaf label, an internal node (feature, left, right),
    or None for an empty subset (caller substitutes the parent majority).
    """
    if not data:
        return None
    labels = [label for _, label in data]
    # Pure node: every sample carries the same label.
    if labels.count(labels[0]) == len(labels):
        return labels[0]
    if not features_available:
        return majority_label(labels)
    total = len(data)
    base = entropy(labels)
    best_gain, best_feat = -1, -1
    # sorted() + strict '>' means ties keep the smallest feature index.
    for feat in sorted(features_available):
        zeros = [label for fv, label in data if fv[feat] == 0]
        ones = [label for fv, label in data if fv[feat] == 1]
        weighted = (len(zeros) / total) * entropy(zeros) + (len(ones) / total) * entropy(ones)
        gain = base - weighted
        if gain > best_gain:
            best_gain = gain
            best_feat = feat
    # No informative split (within epsilon): stop with the majority label.
    if best_gain <= 1e-12:
        return majority_label(labels)
    remaining = features_available - {best_feat}
    fallback = majority_label(labels)
    children = []
    for value in (0, 1):
        child = build_tree([d for d in data if d[0][best_feat] == value], remaining)
        if child is None:
            # Empty branch inherits this node's majority label.
            child = fallback
        children.append(child)
    return (best_feat, children[0], children[1])
def predict(tree, features):
    """Descend the tree following the query's feature bits; return the leaf label."""
    node = tree
    while not isinstance(node, int):
        feat, zero_branch, one_branch = node
        node = zero_branch if features[feat] == 0 else one_branch
    return node
def main():
    """Read all input, train the ID3 tree, and print one prediction per query line."""
    tokens = iter(sys.stdin.read().split())

    def next_int():
        return int(next(tokens))

    n = next_int()
    m = next_int()
    samples = []
    for _ in range(n):
        row = [next_int() for _ in range(m)]
        samples.append((row, next_int()))  # (feature vector, label)
    q = next_int()
    model = build_tree(samples, set(range(m)))
    answers = []
    for _ in range(q):
        query = [next_int() for _ in range(m)]
        answers.append(str(predict(model, query)))
    print('\n'.join(answers))


main()
#include <bits/stdc++.h>
using namespace std;
// A decision-tree node stored flat in the global `tree` vector;
// children are referenced by index (-1 when unused).
struct Node {
    int type; // 0=leaf, 1=internal
    int label, feat, left, right; // leaf: label set; internal: feat + left/right child indices
};
vector<Node> tree; // flat node pool; buildTree returns each subtree's root index
int n, m;          // n = number of training samples, m = number of binary features
// Binary Shannon entropy (base 2) of a 0/1 label vector; 0.0 for empty input.
double entropy(const vector<int>& labels) {
    const int total = (int)labels.size();
    if (total == 0) return 0.0;
    int ones = 0;
    for (int v : labels) ones += v;
    const int zeros = total - ones;
    double h = 0.0;
    // Accumulate the zeros term first, then the ones term (skip empty classes).
    for (int c : {zeros, ones}) {
        if (c > 0) {
            const double p = (double)c / total;
            h -= p * log2(p);
        }
    }
    return h;
}
// Majority vote over 0/1 labels; ties (and empty input) resolve to 0.
int majorityLabel(const vector<int>& labels) {
    int ones = 0;
    for (int v : labels) ones += v;
    const int zeros = (int)labels.size() - ones;
    return ones > zeros ? 1 : 0;
}
// Recursively builds an ID3 subtree over the training rows listed in
// `indices` and returns the new node's index in the global `tree` vector.
//
// feats:     training matrix; row i holds m feature bits and feats[i][m]
//            is the label (global m is the feature count).
// indices:   which training rows belong to this node.
// available: feature indices not yet used on this root-to-node path;
//            mutated (erase/insert) around the recursive calls and
//            restored before returning (backtracking).
int buildTree(const vector<vector<int>>& feats, const vector<int>& indices, set<int>& available) {
    // Reserve this node's slot up front so children pushed during the
    // recursion get strictly larger indices; the slot is filled at the end.
    int idx = tree.size();
    tree.push_back({});
    if (indices.empty()) {
        // Defensive: empty subsets are normally intercepted by the caller's
        // majority-leaf shortcut below; emit a label-0 leaf just in case.
        tree[idx] = {0, 0, -1, -1, -1};
        return idx;
    }
    vector<int> labels;
    for (int i : indices) labels.push_back(feats[i][m]);
    // Pure node: all samples share one label -> leaf.
    if (set<int>(labels.begin(), labels.end()).size() == 1) {
        tree[idx] = {0, labels[0], -1, -1, -1};
        return idx;
    }
    // No features left on this path -> majority-vote leaf (ties -> 0).
    if (available.empty()) {
        tree[idx] = {0, majorityLabel(labels), -1, -1, -1};
        return idx;
    }
    double baseEnt = entropy(labels);
    double bestGain = -1;
    int bestFeat = -1;
    int sz = indices.size();
    // Pick the feature with the largest information gain; iterating the
    // ordered set with a strict '>' keeps the smallest index on ties.
    for (int f : available) {
        vector<int> left, right;
        for (int i : indices) {
            if (feats[i][f] == 0) left.push_back(feats[i][m]);
            else right.push_back(feats[i][m]);
        }
        double wEnt = ((double)left.size()/sz)*entropy(left) + ((double)right.size()/sz)*entropy(right);
        double gain = baseEnt - wEnt;
        if (gain > bestGain) { bestGain = gain; bestFeat = f; }
    }
    // No positive gain (within epsilon): stop splitting, emit majority leaf.
    if (bestGain <= 1e-12) {
        tree[idx] = {0, majorityLabel(labels), -1, -1, -1};
        return idx;
    }
    int maj = majorityLabel(labels);
    vector<int> leftIdx, rightIdx;
    for (int i : indices) {
        if (feats[i][bestFeat] == 0) leftIdx.push_back(i);
        else rightIdx.push_back(i);
    }
    available.erase(bestFeat);
    // An empty side becomes a leaf carrying this node's majority label;
    // the comma expression pushes that leaf and yields its index.
    int leftChild = leftIdx.empty() ? (int)(tree.push_back({0, maj, -1, -1, -1}), tree.size()-1) : buildTree(feats, leftIdx, available);
    int rightChild = rightIdx.empty() ? (int)(tree.push_back({0, maj, -1, -1, -1}), tree.size()-1) : buildTree(feats, rightIdx, available);
    available.insert(bestFeat);
    // Fill the reserved slot via the index (not a saved reference: the
    // push_backs in the recursion may have reallocated `tree`).
    tree[idx] = {1, -1, bestFeat, leftChild, rightChild};
    return idx;
}
int predict(int node, const vector<int>& feat) {
if (tree[node].type == 0) return tree[node].label;
return predict(feat[tree[node].feat] == 0 ? tree[node].left : tree[node].right, feat);
}
// Reads the training set, builds the ID3 tree, then answers each query.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    // Row layout: m feature bits followed by the label in column m.
    vector<vector<int>> rows(n, vector<int>(m + 1));
    for (auto& row : rows)
        for (int& cell : row) cin >> cell;
    vector<int> order(n);
    iota(order.begin(), order.end(), 0);
    set<int> usable;
    for (int f = 0; f < m; f++) usable.insert(f);
    tree.reserve(1000);
    const int root = buildTree(rows, order, usable);
    int q;
    cin >> q;
    while (q--) {
        vector<int> query(m);
        for (int& bit : query) cin >> bit;
        cout << predict(root, query) << '\n';
    }
}

京公网安备 11010502036488号