Problem Link
Problem Description
Given a batch of samples, each with $m$ binary features and a binary label (0 or 1), train a binary classifier with the ID3 decision tree algorithm and use it to predict $q$ query samples.
Rules and details:
- Split criterion: information gain. At the current node, choose, among the features not yet used on the path from the root, the one with the largest information gain.
- Tie-breaking: if several features have the same information gain, pick the one with the smaller index.
- Termination conditions (creating a leaf node):
  - If all samples at the current node have the same label, the node becomes a leaf with that label.
  - If every feature has already been used for splitting, or no feature yields a positive information gain, the node becomes a leaf predicting the majority label (0 if the labels tie).
- Empty branches: if a split leaves one side with an empty sample set, that branch directly becomes a leaf predicting the majority label of its parent node's sample set.
- Prediction: starting from the root, move left or right according to the sample's feature value (0 or 1) until a leaf is reached; the leaf's value is the prediction.
Solution
The task is to implement an ID3 decision tree from scratch. The core of the solution is to build the tree recursively, selecting the best splitting feature at each node.
1. Core concept: information gain
The ID3 algorithm chooses the splitting feature by information gain.
- Entropy: a measure of how pure a data set is; the smaller the entropy, the purer the set (i.e., the more uniform its labels). For a sample set $D$, the entropy is defined as

$$H(D) = -\sum_{k} p_k \log_2 p_k$$

where $p_k$ is the proportion of class $k$ in $D$. For the binary classification in this problem, $H(D) = -p_0 \log_2 p_0 - p_1 \log_2 p_1$, with the convention $0 \log_2 0 = 0$.
- Information gain: the reduction in uncertainty about the data set $D$ once the value of feature $A$ is known:

$$\mathrm{Gain}(D, A) = H(D) - \sum_{v} \frac{|D_v|}{|D|} H(D_v)$$

For the binary features in this problem the formula simplifies to

$$\mathrm{Gain}(D, A) = H(D) - \frac{|D_0|}{|D|} H(D_0) - \frac{|D_1|}{|D|} H(D_1)$$

where $D_0$ and $D_1$ are the subsets of samples whose feature $A$ equals 0 and 1, respectively. The larger the information gain, the better the split (a small numeric check follows this list).
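To make the formulas concrete, here is a minimal standalone Python sketch; the helper names entropy / info_gain and the four-sample toy data are purely illustrative, not part of the problem. With labels [0, 0, 1, 1], a feature that separates the two classes perfectly has gain 1.0, while a feature that splits them evenly has gain 0.0.

import math

def entropy(labels):
    # H(D) for a list of 0/1 labels
    if not labels:
        return 0.0
    p1 = sum(labels) / len(labels)
    return sum(-p * math.log2(p) for p in (p1, 1 - p1) if p > 0)

def info_gain(samples, feat):
    # Gain(D, A) when splitting on binary feature `feat`; samples are (features, label) pairs
    labels = [y for _, y in samples]
    split = {0: [], 1: []}
    for x, y in samples:
        split[x[feat]].append(y)
    cond = sum(len(part) / len(samples) * entropy(part) for part in split.values())
    return entropy(labels) - cond

# Toy data: feature 0 equals the label, feature 1 is uninformative.
data = [([0, 0], 0), ([0, 1], 0), ([1, 0], 1), ([1, 1], 1)]
print(entropy([y for _, y in data]))  # 1.0
print(info_gain(data, 0))             # 1.0 -> chosen for the split
print(info_gain(data, 1))             # 0.0 -> no positive gain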
2. Building the tree recursively
We define a recursive function build_tree(samples, used_features) to construct the tree. Here samples is the set of training samples reaching the current node, and used_features is the set of features already used on the path from the root to this node.
Termination conditions (creating a leaf node):
- All samples in samples have the same label: create a leaf whose prediction is that label.
- used_features already contains every feature: create a leaf predicting the majority label of samples (0 on a tie).
- Otherwise compute the information gain of every unused feature; if the maximum gain is $\le 0$: create a leaf predicting the majority label of samples (0 on a tie).
Recursive splitting:
- Among the unused features, find the feature $A^*$ with the largest information gain (smallest index on ties; scanning features in increasing index order and updating only on a strictly larger gain achieves this automatically).
- Create an internal node whose splitting feature is $A^*$.
- Split samples into two subsets: $S_0$ (feature $A^*$ equals 0) and $S_1$ (feature $A^*$ equals 1).
- Empty branches: if $S_0$ or $S_1$ is empty, the corresponding child directly becomes a leaf predicting the majority label of the parent's samples.
- For each non-empty subset, recurse to build the child:
  - left child: build_tree(S_0, used_features + {A*})
  - right child: build_tree(S_1, used_features + {A*})
3. Prediction
Prediction is straightforward. For a query sample, start at the root of the tree:
- If the current node is a leaf, return its prediction.
- If it is an internal node, read its splitting feature $A$ and look at the query sample's value of feature $A$:
  - if the value is 0, move to the left child;
  - if the value is 1, move to the right child.
- Repeat until a leaf is reached.
Code
C++
#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <map>
#include <algorithm>
using namespace std;

// Training sample: feature vector plus label
struct Sample {
    vector<int> features;
    int label;
};

// Decision tree node (internal node or leaf)
struct Node {
    bool is_leaf = false;
    int prediction = -1;
    int feature_index = -1;
    Node* left = nullptr;
    Node* right = nullptr;
    ~Node() {
        delete left;
        delete right;
    }
};

// Entropy of the labels of the given sample indices
double calculate_entropy(const vector<int>& sample_indices, const vector<Sample>& all_samples) {
    if (sample_indices.empty()) {
        return 0.0;
    }
    map<int, int> counts;
    for (int idx : sample_indices) {
        counts[all_samples[idx].label]++;
    }
    double entropy = 0.0;
    for (auto const& [label, count] : counts) {
        double p = (double)count / sample_indices.size();
        if (p > 0) {
            entropy -= p * log2(p);
        }
    }
    return entropy;
}

// Majority label (returns 0 on a tie)
int get_majority_label(const vector<int>& sample_indices, const vector<Sample>& all_samples) {
    map<int, int> counts;
    for (int idx : sample_indices) {
        counts[all_samples[idx].label]++;
    }
    if (counts[0] >= counts[1]) {
        return 0;
    }
    return 1;
}

// Recursively build the decision tree
Node* build_tree(const vector<int>& sample_indices, vector<bool>& used_features, const vector<Sample>& all_samples, int m) {
    Node* node = new Node();
    // Termination 1: all labels identical
    int first_label = all_samples[sample_indices[0]].label;
    bool all_same = all_of(sample_indices.begin(), sample_indices.end(),
                           [&](int idx){ return all_samples[idx].label == first_label; });
    if (all_same) {
        node->is_leaf = true;
        node->prediction = first_label;
        return node;
    }
    // Termination 2: all features used
    bool all_features_used = all_of(used_features.begin(), used_features.end(), [](bool v){ return v; });
    if (all_features_used) {
        node->is_leaf = true;
        node->prediction = get_majority_label(sample_indices, all_samples);
        return node;
    }
    // Find the best splitting feature (ties keep the smaller index, since we only update on a strictly larger gain)
    double best_gain = 0.0;
    int best_feature_index = -1;
    double parent_entropy = calculate_entropy(sample_indices, all_samples);
    for (int i = 0; i < m; ++i) {
        if (used_features[i]) continue;
        vector<int> left_indices, right_indices;
        for (int idx : sample_indices) {
            if (all_samples[idx].features[i] == 0) left_indices.push_back(idx);
            else right_indices.push_back(idx);
        }
        double left_entropy = calculate_entropy(left_indices, all_samples);
        double right_entropy = calculate_entropy(right_indices, all_samples);
        double total_size = sample_indices.size();
        double conditional_entropy = (left_indices.size() / total_size) * left_entropy + (right_indices.size() / total_size) * right_entropy;
        double gain = parent_entropy - conditional_entropy;
        if (gain > best_gain) {
            best_gain = gain;
            best_feature_index = i;
        }
    }
    // Termination 3: no positive gain (small epsilon for floating-point comparison)
    if (best_gain <= 1e-9) {
        node->is_leaf = true;
        node->prediction = get_majority_label(sample_indices, all_samples);
        return node;
    }
    // Split and recurse
    node->feature_index = best_feature_index;
    used_features[best_feature_index] = true;
    vector<int> left_indices, right_indices;
    for (int idx : sample_indices) {
        if (all_samples[idx].features[best_feature_index] == 0) left_indices.push_back(idx);
        else right_indices.push_back(idx);
    }
    int majority_label = get_majority_label(sample_indices, all_samples);
    if (left_indices.empty()) {
        // Empty branch: leaf carrying the parent's majority label
        node->left = new Node();
        node->left->is_leaf = true;
        node->left->prediction = majority_label;
    } else {
        node->left = build_tree(left_indices, used_features, all_samples, m);
    }
    if (right_indices.empty()) {
        node->right = new Node();
        node->right->is_leaf = true;
        node->right->prediction = majority_label;
    } else {
        node->right = build_tree(right_indices, used_features, all_samples, m);
    }
    used_features[best_feature_index] = false; // backtrack: feature usage is per root-to-leaf path
    return node;
}

// Predict by walking from the root to a leaf
int predict(const Node* node, const vector<int>& features) {
    while (!node->is_leaf) {
        if (features[node->feature_index] == 0) {
            node = node->left;
        } else {
            node = node->right;
        }
    }
    return node->prediction;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    int n, m;
    cin >> n >> m;
    vector<Sample> training_data(n);
    vector<int> initial_indices(n);
    for (int i = 0; i < n; ++i) {
        training_data[i].features.resize(m);
        for (int j = 0; j < m; ++j) {
            cin >> training_data[i].features[j];
        }
        cin >> training_data[i].label;
        initial_indices[i] = i;
    }
    vector<bool> used_features(m, false);
    Node* root = build_tree(initial_indices, used_features, training_data, m);
    int q;
    cin >> q;
    for (int i = 0; i < q; ++i) {
        vector<int> query_features(m);
        for (int j = 0; j < m; ++j) {
            cin >> query_features[j];
        }
        cout << predict(root, query_features) << "\n";
    }
    delete root;
    return 0;
}
Java
import java.util.*;
import java.util.stream.Collectors;

public class Main {
    // Training sample: feature vector plus label
    static class Sample {
        int[] features;
        int label;
        Sample(int m) {
            features = new int[m];
        }
    }

    // Decision tree node (internal node or leaf)
    static class Node {
        boolean isLeaf = false;
        int prediction = -1;
        int featureIndex = -1;
        Node left = null;
        Node right = null;
    }

    // Entropy of the labels of the given sample indices
    private static double calculateEntropy(List<Integer> sampleIndices, List<Sample> allSamples) {
        if (sampleIndices.isEmpty()) {
            return 0.0;
        }
        Map<Integer, Long> counts = sampleIndices.stream()
                .collect(Collectors.groupingBy(idx -> allSamples.get(idx).label, Collectors.counting()));
        double entropy = 0.0;
        for (long count : counts.values()) {
            double p = (double) count / sampleIndices.size();
            if (p > 0) {
                entropy -= p * (Math.log(p) / Math.log(2));
            }
        }
        return entropy;
    }

    // Majority label (returns 0 on a tie)
    private static int getMajorityLabel(List<Integer> sampleIndices, List<Sample> allSamples) {
        long count0 = sampleIndices.stream().filter(idx -> allSamples.get(idx).label == 0).count();
        long count1 = sampleIndices.size() - count0;
        return (count0 >= count1) ? 0 : 1;
    }

    // Recursively build the decision tree
    private static Node buildTree(List<Integer> sampleIndices, boolean[] usedFeatures, List<Sample> allSamples, int m) {
        Node node = new Node();
        // Termination 1: all labels identical
        int firstLabel = allSamples.get(sampleIndices.get(0)).label;
        if (sampleIndices.stream().allMatch(idx -> allSamples.get(idx).label == firstLabel)) {
            node.isLeaf = true;
            node.prediction = firstLabel;
            return node;
        }
        // Termination 2: all features used
        boolean allFeaturesUsed = true;
        for (boolean used : usedFeatures) {
            if (!used) {
                allFeaturesUsed = false;
                break;
            }
        }
        if (allFeaturesUsed) {
            node.isLeaf = true;
            node.prediction = getMajorityLabel(sampleIndices, allSamples);
            return node;
        }
        // Find the best splitting feature (ties keep the smaller index, since we only update on a strictly larger gain)
        double bestGain = 0.0;
        int bestFeatureIndex = -1;
        double parentEntropy = calculateEntropy(sampleIndices, allSamples);
        for (int i = 0; i < m; i++) {
            if (usedFeatures[i]) continue;
            List<Integer> leftIndices = new ArrayList<>();
            List<Integer> rightIndices = new ArrayList<>();
            for (int idx : sampleIndices) {
                if (allSamples.get(idx).features[i] == 0) leftIndices.add(idx);
                else rightIndices.add(idx);
            }
            double leftEntropy = calculateEntropy(leftIndices, allSamples);
            double rightEntropy = calculateEntropy(rightIndices, allSamples);
            double totalSize = sampleIndices.size();
            double conditionalEntropy = (leftIndices.size() / totalSize) * leftEntropy + (rightIndices.size() / totalSize) * rightEntropy;
            double gain = parentEntropy - conditionalEntropy;
            if (gain > bestGain) {
                bestGain = gain;
                bestFeatureIndex = i;
            }
        }
        // Termination 3: no positive gain
        if (bestGain <= 1e-9) { // small epsilon for floating-point comparison
            node.isLeaf = true;
            node.prediction = getMajorityLabel(sampleIndices, allSamples);
            return node;
        }
        // Split and recurse
        node.featureIndex = bestFeatureIndex;
        usedFeatures[bestFeatureIndex] = true;
        List<Integer> leftIndices = new ArrayList<>();
        List<Integer> rightIndices = new ArrayList<>();
        for (int idx : sampleIndices) {
            if (allSamples.get(idx).features[bestFeatureIndex] == 0) leftIndices.add(idx);
            else rightIndices.add(idx);
        }
        int majorityLabel = getMajorityLabel(sampleIndices, allSamples);
        if (leftIndices.isEmpty()) {
            // Empty branch: leaf carrying the parent's majority label
            node.left = new Node();
            node.left.isLeaf = true;
            node.left.prediction = majorityLabel;
        } else {
            node.left = buildTree(leftIndices, usedFeatures, allSamples, m);
        }
        if (rightIndices.isEmpty()) {
            node.right = new Node();
            node.right.isLeaf = true;
            node.right.prediction = majorityLabel;
        } else {
            node.right = buildTree(rightIndices, usedFeatures, allSamples, m);
        }
        usedFeatures[bestFeatureIndex] = false; // backtrack: feature usage is per root-to-leaf path
        return node;
    }

    // Predict by walking from the root to a leaf
    private static int predict(Node node, int[] features) {
        while (!node.isLeaf) {
            if (features[node.featureIndex] == 0) {
                node = node.left;
            } else {
                node = node.right;
            }
        }
        return node.prediction;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        List<Sample> trainingData = new ArrayList<>();
        List<Integer> initialIndices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            Sample sample = new Sample(m);
            for (int j = 0; j < m; j++) {
                sample.features[j] = sc.nextInt();
            }
            sample.label = sc.nextInt();
            trainingData.add(sample);
            initialIndices.add(i);
        }
        boolean[] usedFeatures = new boolean[m];
        Node root = buildTree(initialIndices, usedFeatures, trainingData, m);
        int q = sc.nextInt();
        for (int i = 0; i < q; i++) {
            int[] queryFeatures = new int[m];
            for (int j = 0; j < m; j++) {
                queryFeatures[j] = sc.nextInt();
            }
            System.out.println(predict(root, queryFeatures));
        }
    }
}
Python
import math
from collections import Counter

# Decision tree node (internal node or leaf)
class Node:
    def __init__(self, is_leaf=False, prediction=-1, feature_index=-1, left=None, right=None):
        self.is_leaf = is_leaf
        self.prediction = prediction
        self.feature_index = feature_index
        self.left = left
        self.right = right

# Entropy of the labels of the given sample indices
def calculate_entropy(sample_indices, all_samples):
    if not sample_indices:
        return 0
    labels = [all_samples[i][1] for i in sample_indices]
    counts = Counter(labels)
    entropy = 0
    total_size = len(sample_indices)
    for count in counts.values():
        p = count / total_size
        if p > 0:
            entropy -= p * math.log2(p)
    return entropy

# Majority label (returns 0 on a tie)
def get_majority_label(sample_indices, all_samples):
    labels = [all_samples[i][1] for i in sample_indices]
    counts = Counter(labels)
    if counts[0] >= counts[1]:
        return 0
    return 1

# Recursively build the decision tree
def build_tree(sample_indices, used_features, all_samples, m):
    # Termination 1: all labels identical
    first_label = all_samples[sample_indices[0]][1]
    if all(all_samples[i][1] == first_label for i in sample_indices):
        return Node(is_leaf=True, prediction=first_label)
    # Termination 2: all features used
    if len(used_features) == m:
        return Node(is_leaf=True, prediction=get_majority_label(sample_indices, all_samples))
    # Find the best splitting feature (ties keep the smaller index, since we only update on a strictly larger gain)
    best_gain = 0.0
    best_feature_index = -1
    parent_entropy = calculate_entropy(sample_indices, all_samples)
    for i in range(m):
        if i in used_features:
            continue
        left_indices = [idx for idx in sample_indices if all_samples[idx][0][i] == 0]
        right_indices = [idx for idx in sample_indices if all_samples[idx][0][i] == 1]
        total_size = len(sample_indices)
        left_entropy = calculate_entropy(left_indices, all_samples)
        right_entropy = calculate_entropy(right_indices, all_samples)
        conditional_entropy = (len(left_indices) / total_size) * left_entropy + \
                              (len(right_indices) / total_size) * right_entropy
        gain = parent_entropy - conditional_entropy
        if gain > best_gain:
            best_gain = gain
            best_feature_index = i
    # Termination 3: no positive gain (small epsilon for floating-point comparison)
    if best_gain <= 1e-9:
        return Node(is_leaf=True, prediction=get_majority_label(sample_indices, all_samples))
    # Split and recurse
    node = Node(feature_index=best_feature_index)
    new_used_features = used_features | {best_feature_index}
    left_indices = [idx for idx in sample_indices if all_samples[idx][0][best_feature_index] == 0]
    right_indices = [idx for idx in sample_indices if all_samples[idx][0][best_feature_index] == 1]
    majority_label = get_majority_label(sample_indices, all_samples)
    # Empty branch: leaf carrying the parent's majority label
    node.left = Node(is_leaf=True, prediction=majority_label) if not left_indices else \
        build_tree(left_indices, new_used_features, all_samples, m)
    node.right = Node(is_leaf=True, prediction=majority_label) if not right_indices else \
        build_tree(right_indices, new_used_features, all_samples, m)
    return node

# Predict by walking from the root to a leaf
def predict(node, features):
    while not node.is_leaf:
        if features[node.feature_index] == 0:
            node = node.left
        else:
            node = node.right
    return node.prediction

def solve():
    n, m = map(int, input().split())
    training_data = []
    for _ in range(n):
        line = list(map(int, input().split()))
        training_data.append((line[:m], line[m]))
    initial_indices = list(range(n))
    used_features = set()
    root = build_tree(initial_indices, used_features, training_data, m)
    q = int(input())
    for _ in range(q):
        query_features = list(map(int, input().split()))
        print(predict(root, query_features))

solve()
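For a quick end-to-end sanity check, note that all three programs above read the same layout: first n and m, then n lines each containing m feature values followed by the label, then q, then q query lines of m feature values. On the following hypothetical input, feature 0 coincides with the label, so the root splits on feature 0 and the two queries are classified by their first feature:

Input:
4 2
0 0 0
0 1 0
1 0 1
1 1 1
2
0 1
1 0

Expected output:
0
1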
Algorithm and Complexity
- Algorithm: ID3 decision tree
- Time complexity: $O(n \cdot m^2 + q \cdot m)$, where $n$ is the number of training samples, $m$ the number of features, and $q$ the number of queries. At every level of the tree we pass over all samples that reach that level and compute the information gain of each unused feature, which costs $O(n \cdot m)$ per level; since every feature is used at most once on a root-to-leaf path, the depth is at most $m$, so building the tree takes roughly $O(n \cdot m^2)$. Predicting one sample walks from the root to a leaf in $O(m)$.
- Space complexity: $O(n \cdot m)$, dominated by storing the training data; the tree itself is comparatively small.