Problem Link

ID3 Decision Tree Training

Problem Description

Given a batch of samples, each with $m$ binary features and a binary class label (0 or 1), train a binary classifier based on the ID3 decision tree algorithm and use it to predict the labels of $q$ query samples.

Rules and Details

  • Splitting criterion: information gain. At the current node, among the features not yet used on the path, choose the one with the largest information gain.
  • Tie-breaking: if several features have the same information gain, choose the one with the smaller index.
  • Termination conditions (creating a leaf node)
    1. If all samples at the current node have the same label, the node becomes a leaf with that label.
    2. If every feature has already been used for splitting, or no feature yields a positive information gain, the node becomes a leaf with the majority label (0 in case of a tie).
  • Empty branches: if a split leaves one side with an empty sample set, that branch becomes a leaf whose value is the majority label of the parent node's sample set.
  • Prediction: starting from the root, move left or right according to the sample's feature values (0 or 1) until a leaf is reached; the leaf's value is the prediction.

Solution Approach

This problem asks for an ID3 decision tree implemented from scratch. The core is to build the tree recursively, selecting the best splitting feature at each node to partition the data.

1. Core Concept: Information Gain

The ID3 algorithm uses information gain to choose the best splitting feature.

  • Entropy: a measure of how pure a data set is; the smaller the entropy, the purer the set (i.e., the more uniform the labels). For a sample set $S$, the entropy $H(S)$ is defined as $H(S) = -\sum_{k} p_k \log_2 p_k$, where $p_k$ is the proportion of class $k$ in $S$. For the binary classification in this problem, $H(S) = -p_0 \log_2 p_0 - p_1 \log_2 p_1$ (with the convention $0 \log_2 0 = 0$).

  • Information gain: the reduction in the uncertainty of data set $S$ once feature $A$ is known, computed as $Gain(S, A) = H(S) - \sum_{v} \frac{|S_v|}{|S|} H(S_v)$. For the binary features in this problem it simplifies to $Gain(S, A) = H(S) - \frac{|S_0|}{|S|} H(S_0) - \frac{|S_1|}{|S|} H(S_1)$, where $S_0$ and $S_1$ are the subsets of samples whose feature $A$ equals 0 and 1, respectively. The larger the gain, the better the split (see the worked example below).
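
As a small worked example (the four samples here are made up purely for illustration), take $(0,0)\to 0$, $(0,1)\to 0$, $(1,0)\to 1$, $(1,1)\to 0$, where each pair lists features 0 and 1 and the arrow gives the label. The parent entropy is $H(S) = -\frac{3}{4}\log_2\frac{3}{4} - \frac{1}{4}\log_2\frac{1}{4} \approx 0.811$. Splitting on feature 0 gives $S_0 = \{(0,0),(0,1)\}$ with $H(S_0) = 0$ and $S_1 = \{(1,0),(1,1)\}$ with $H(S_1) = 1$, so $Gain(S, 0) = 0.811 - \frac{2}{4}\cdot 0 - \frac{2}{4}\cdot 1 \approx 0.311$. Splitting on feature 1 happens to give exactly the same gain, so the tie-breaking rule selects feature 0, the smaller index.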

2. Building the Decision Tree Recursively

We can define a recursive function build_tree(samples, used_features) to construct the tree.

  • samples is the set of training samples contained in the current node.
  • used_features is the set of features already used for splitting on the path from the root to the current node.

Termination conditions of the recursion (i.e., when to create a leaf node)

  1. All samples in samples have the same label: create a leaf node whose prediction is that label.
  2. used_features already contains every feature: create a leaf node whose prediction is the majority label of samples (0 on a tie).
  3. Go through all unused features and compute their information gain. If the largest gain is $\le 0$: create a leaf node whose prediction is the majority label of samples (0 on a tie).

Recursive splitting steps

  1. Among the unused features, find the feature $A^*$ with the largest information gain (the smaller index on ties).
  2. Create an internal node and record $A^*$ as its splitting feature.
  3. Partition samples into two subsets: $S_0$ (feature $A^*$ equals 0) and $S_1$ (feature $A^*$ equals 1).
  4. Handle empty branches: if $S_0$ or $S_1$ is empty, the corresponding child becomes a leaf whose prediction is the majority label of the parent's samples.
  5. For each non-empty subset, recursively call build_tree to create the child node (a short trace on the illustrative data from section 1 follows this list):
    • Left child: build_tree(S_0, used_features + {A*})
    • Right child: build_tree(S_1, used_features + {A*})
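
Running these steps by hand on the illustrative four-sample dataset from section 1: at the root, features 0 and 1 both have gain $\approx 0.311$, so feature 0 wins the tie. Its left child receives $S_0$ with labels $\{0, 0\}$, all identical, so it becomes a leaf predicting 0. Its right child receives $S_1$ with labels $\{1, 0\}$; the only remaining feature is feature 1, whose gain there is $1 - 0 = 1 > 0$, so the node splits again into a leaf predicting 1 (feature 1 = 0) and a leaf predicting 0 (feature 1 = 1).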

3. Prediction

Prediction is straightforward. For a query sample, start at the root of the tree:

  1. If the current node is a leaf, return its prediction.
  2. If it is an internal node, read its splitting feature $A$.
  3. Check the value of feature $A$ in the query sample:
    • If the value is 0, move to the left child.
    • If the value is 1, move to the right child.
  4. Repeat until a leaf node is reached (a short example trace follows).
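
For instance, on the illustrative tree traced above, the query $(1, 0)$ is classified as follows: the root splits on feature 0, whose value here is 1, so move right; that node splits on feature 1, whose value here is 0, so move left; the node reached is a leaf, and the prediction is 1.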

Code

C++

#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <map>
#include <algorithm>

using namespace std;

// Sample: feature vector and label
struct Sample {
    vector<int> features;
    int label;
};

// Decision tree node
struct Node {
    bool is_leaf = false;
    int prediction = -1;
    int feature_index = -1;
    Node* left = nullptr;
    Node* right = nullptr;
    
    ~Node() {
        delete left;
        delete right;
    }
};

// Entropy of the samples referenced by sample_indices
double calculate_entropy(const vector<int>& sample_indices, const vector<Sample>& all_samples) {
    if (sample_indices.empty()) {
        return 0.0;
    }
    map<int, int> counts;
    for (int idx : sample_indices) {
        counts[all_samples[idx].label]++;
    }
    double entropy = 0.0;
    for (auto const& [label, count] : counts) {
        double p = (double)count / sample_indices.size();
        if (p > 0) {
            entropy -= p * log2(p);
        }
    }
    return entropy;
}

// Majority label (ties return 0)
int get_majority_label(const vector<int>& sample_indices, const vector<Sample>& all_samples) {
    map<int, int> counts;
    for (int idx : sample_indices) {
        counts[all_samples[idx].label]++;
    }
    if (counts[0] >= counts[1]) {
        return 0;
    }
    return 1;
}

// Recursively build the decision tree
Node* build_tree(const vector<int>& sample_indices, vector<bool>& used_features, const vector<Sample>& all_samples, int m) {
    Node* node = new Node();

    // Termination 1: all labels identical
    int first_label = all_samples[sample_indices[0]].label;
    bool all_same = all_of(sample_indices.begin(), sample_indices.end(), 
        [&](int idx){ return all_samples[idx].label == first_label; });
    if (all_same) {
        node->is_leaf = true;
        node->prediction = first_label;
        return node;
    }

    // Termination 2: all features used
    bool all_features_used = all_of(used_features.begin(), used_features.end(), [](bool v){ return v; });
    if (all_features_used) {
        node->is_leaf = true;
        node->prediction = get_majority_label(sample_indices, all_samples);
        return node;
    }

    // Find the best splitting feature
    double best_gain = 0.0;
    int best_feature_index = -1;
    double parent_entropy = calculate_entropy(sample_indices, all_samples);

    for (int i = 0; i < m; ++i) {
        if (used_features[i]) continue;
        
        vector<int> left_indices, right_indices;
        for (int idx : sample_indices) {
            if (all_samples[idx].features[i] == 0) left_indices.push_back(idx);
            else right_indices.push_back(idx);
        }

        double left_entropy = calculate_entropy(left_indices, all_samples);
        double right_entropy = calculate_entropy(right_indices, all_samples);
        double total_size = sample_indices.size();
        double conditional_entropy = (left_indices.size() / total_size) * left_entropy + (right_indices.size() / total_size) * right_entropy;
        double gain = parent_entropy - conditional_entropy;

        if (gain > best_gain) {
            best_gain = gain;
            best_feature_index = i;
        }
    }
    
    // Termination 3: no positive information gain (small epsilon for floating-point comparison)
    if (best_gain <= 1e-9) {
        node->is_leaf = true;
        node->prediction = get_majority_label(sample_indices, all_samples);
        return node;
    }

    // Recursive split
    node->feature_index = best_feature_index;
    used_features[best_feature_index] = true;
    
    vector<int> left_indices, right_indices;
    for (int idx : sample_indices) {
        if (all_samples[idx].features[best_feature_index] == 0) left_indices.push_back(idx);
        else right_indices.push_back(idx);
    }

    int majority_label = get_majority_label(sample_indices, all_samples);

    if (left_indices.empty()) {
        node->left = new Node{true, majority_label};
    } else {
        node->left = build_tree(left_indices, used_features, all_samples, m);
    }

    if (right_indices.empty()) {
        node->right = new Node{true, majority_label};
    } else {
        node->right = build_tree(right_indices, used_features, all_samples, m);
    }
    
    used_features[best_feature_index] = false; // backtrack so the feature stays available on other paths
    return node;
}

// Prediction: walk from the root to a leaf
int predict(const Node* node, const vector<int>& features) {
    while (!node->is_leaf) {
        if (features[node->feature_index] == 0) {
            node = node->left;
        } else {
            node = node->right;
        }
    }
    return node->prediction;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;
    vector<Sample> training_data(n);
    vector<int> initial_indices(n);
    for (int i = 0; i < n; ++i) {
        training_data[i].features.resize(m);
        for (int j = 0; j < m; ++j) {
            cin >> training_data[i].features[j];
        }
        cin >> training_data[i].label;
        initial_indices[i] = i;
    }

    vector<bool> used_features(m, false);
    Node* root = build_tree(initial_indices, used_features, training_data, m);

    int q;
    cin >> q;
    for (int i = 0; i < q; ++i) {
        vector<int> query_features(m);
        for (int j = 0; j < m; ++j) {
            cin >> query_features[j];
        }
        cout << predict(root, query_features) << "\n";
    }
    
    delete root;
    return 0;
}

Java

import java.util.*;
import java.util.stream.Collectors;

public class Main {

    // Sample: feature vector and label
    static class Sample {
        int[] features;
        int label;

        Sample(int m) {
            features = new int[m];
        }
    }

    // Decision tree node
    static class Node {
        boolean isLeaf = false;
        int prediction = -1;
        int featureIndex = -1;
        Node left = null;
        Node right = null;
    }

    // Entropy of the samples referenced by sampleIndices
    private static double calculateEntropy(List<Integer> sampleIndices, List<Sample> allSamples) {
        if (sampleIndices.isEmpty()) {
            return 0.0;
        }
        Map<Integer, Long> counts = sampleIndices.stream()
                .collect(Collectors.groupingBy(idx -> allSamples.get(idx).label, Collectors.counting()));
        
        double entropy = 0.0;
        for (long count : counts.values()) {
            double p = (double) count / sampleIndices.size();
            if (p > 0) {
                entropy -= p * (Math.log(p) / Math.log(2));
            }
        }
        return entropy;
    }

    // Majority label (ties return 0)
    private static int getMajorityLabel(List<Integer> sampleIndices, List<Sample> allSamples) {
        long count0 = sampleIndices.stream().filter(idx -> allSamples.get(idx).label == 0).count();
        long count1 = sampleIndices.size() - count0;
        return (count0 >= count1) ? 0 : 1;
    }

    // Recursively build the decision tree
    private static Node buildTree(List<Integer> sampleIndices, boolean[] usedFeatures, List<Sample> allSamples, int m) {
        Node node = new Node();

        // Termination 1: all labels identical
        int firstLabel = allSamples.get(sampleIndices.get(0)).label;
        if (sampleIndices.stream().allMatch(idx -> allSamples.get(idx).label == firstLabel)) {
            node.isLeaf = true;
            node.prediction = firstLabel;
            return node;
        }

        // Termination 2: all features used
        boolean allFeaturesUsed = true;
        for (boolean used : usedFeatures) {
            if (!used) {
                allFeaturesUsed = false;
                break;
            }
        }
        if (allFeaturesUsed) {
            node.isLeaf = true;
            node.prediction = getMajorityLabel(sampleIndices, allSamples);
            return node;
        }
        
        // Find the best splitting feature
        double bestGain = 0.0;
        int bestFeatureIndex = -1;
        double parentEntropy = calculateEntropy(sampleIndices, allSamples);

        for (int i = 0; i < m; i++) {
            if (usedFeatures[i]) continue;
            
            List<Integer> leftIndices = new ArrayList<>();
            List<Integer> rightIndices = new ArrayList<>();
            for (int idx : sampleIndices) {
                if (allSamples.get(idx).features[i] == 0) leftIndices.add(idx);
                else rightIndices.add(idx);
            }

            double leftEntropy = calculateEntropy(leftIndices, allSamples);
            double rightEntropy = calculateEntropy(rightIndices, allSamples);
            double totalSize = sampleIndices.size();
            double conditionalEntropy = (leftIndices.size() / totalSize) * leftEntropy + (rightIndices.size() / totalSize) * rightEntropy;
            double gain = parentEntropy - conditionalEntropy;

            if (gain > bestGain) {
                bestGain = gain;
                bestFeatureIndex = i;
            }
        }

        // Termination 3: no positive information gain
        if (bestGain <= 1e-9) { // small epsilon for floating-point comparison
            node.isLeaf = true;
            node.prediction = getMajorityLabel(sampleIndices, allSamples);
            return node;
        }

        // Recursive split
        node.featureIndex = bestFeatureIndex;
        usedFeatures[bestFeatureIndex] = true;

        List<Integer> leftIndices = new ArrayList<>();
        List<Integer> rightIndices = new ArrayList<>();
        for (int idx : sampleIndices) {
            if (allSamples.get(idx).features[bestFeatureIndex] == 0) leftIndices.add(idx);
            else rightIndices.add(idx);
        }

        int majorityLabel = getMajorityLabel(sampleIndices, allSamples);

        if (leftIndices.isEmpty()) {
            node.left = new Node();
            node.left.isLeaf = true;
            node.left.prediction = majorityLabel;
        } else {
            node.left = buildTree(leftIndices, usedFeatures, allSamples, m);
        }

        if (rightIndices.isEmpty()) {
            node.right = new Node();
            node.right.isLeaf = true;
            node.right.prediction = majorityLabel;
        } else {
            node.right = buildTree(rightIndices, usedFeatures, allSamples, m);
        }

        usedFeatures[bestFeatureIndex] = false; // backtrack so the feature stays available on other paths
        return node;
    }

    // Prediction: walk from the root to a leaf
    private static int predict(Node node, int[] features) {
        while (!node.isLeaf) {
            if (features[node.featureIndex] == 0) {
                node = node.left;
            } else {
                node = node.right;
            }
        }
        return node.prediction;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        
        List<Sample> trainingData = new ArrayList<>();
        List<Integer> initialIndices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            Sample sample = new Sample(m);
            for (int j = 0; j < m; j++) {
                sample.features[j] = sc.nextInt();
            }
            sample.label = sc.nextInt();
            trainingData.add(sample);
            initialIndices.add(i);
        }

        boolean[] usedFeatures = new boolean[m];
        Node root = buildTree(initialIndices, usedFeatures, trainingData, m);

        int q = sc.nextInt();
        for (int i = 0; i < q; i++) {
            int[] queryFeatures = new int[m];
            for (int j = 0; j < m; j++) {
                queryFeatures[j] = sc.nextInt();
            }
            System.out.println(predict(root, queryFeatures));
        }
    }
}

Python

import math
from collections import Counter

# Decision tree node
class Node:
    def __init__(self, is_leaf=False, prediction=-1, feature_index=-1, left=None, right=None):
        self.is_leaf = is_leaf
        self.prediction = prediction
        self.feature_index = feature_index
        self.left = left
        self.right = right

# Entropy of the samples referenced by sample_indices
def calculate_entropy(sample_indices, all_samples):
    if not sample_indices:
        return 0
    labels = [all_samples[i][1] for i in sample_indices]
    counts = Counter(labels)
    entropy = 0
    total_size = len(sample_indices)
    for count in counts.values():
        p = count / total_size
        if p > 0:
            entropy -= p * math.log2(p)
    return entropy

# Majority label (ties return 0)
def get_majority_label(sample_indices, all_samples):
    labels = [all_samples[i][1] for i in sample_indices]
    counts = Counter(labels)
    if counts[0] >= counts[1]:
        return 0
    return 1

# Recursively build the decision tree
def build_tree(sample_indices, used_features, all_samples, m):
    # Termination 1: all labels identical
    first_label = all_samples[sample_indices[0]][1]
    if all(all_samples[i][1] == first_label for i in sample_indices):
        return Node(is_leaf=True, prediction=first_label)

    # Termination 2: all features used
    if len(used_features) == m:
        return Node(is_leaf=True, prediction=get_majority_label(sample_indices, all_samples))

    # Find the best splitting feature
    best_gain = 0.0
    best_feature_index = -1
    parent_entropy = calculate_entropy(sample_indices, all_samples)

    for i in range(m):
        if i in used_features:
            continue
        
        left_indices = [idx for idx in sample_indices if all_samples[idx][0][i] == 0]
        right_indices = [idx for idx in sample_indices if all_samples[idx][0][i] == 1]
        
        total_size = len(sample_indices)
        left_entropy = calculate_entropy(left_indices, all_samples)
        right_entropy = calculate_entropy(right_indices, all_samples)
        
        conditional_entropy = (len(left_indices) / total_size) * left_entropy + \
                              (len(right_indices) / total_size) * right_entropy
        gain = parent_entropy - conditional_entropy
        
        if gain > best_gain:
            best_gain = gain
            best_feature_index = i

    # Termination 3: no positive information gain
    if best_gain <= 1e-9: # small epsilon for floating-point comparison
        return Node(is_leaf=True, prediction=get_majority_label(sample_indices, all_samples))

    # Recursive split
    node = Node(feature_index=best_feature_index)
    new_used_features = used_features | {best_feature_index}
    
    left_indices = [idx for idx in sample_indices if all_samples[idx][0][best_feature_index] == 0]
    right_indices = [idx for idx in sample_indices if all_samples[idx][0][best_feature_index] == 1]
    
    majority_label = get_majority_label(sample_indices, all_samples)

    node.left = Node(is_leaf=True, prediction=majority_label) if not left_indices else \
                build_tree(left_indices, new_used_features, all_samples, m)
    
    node.right = Node(is_leaf=True, prediction=majority_label) if not right_indices else \
                 build_tree(right_indices, new_used_features, all_samples, m)

    return node

# Prediction: walk from the root to a leaf
def predict(node, features):
    while not node.is_leaf:
        if features[node.feature_index] == 0:
            node = node.left
        else:
            node = node.right
    return node.prediction

def solve():
    n, m = map(int, input().split())
    training_data = []
    for _ in range(n):
        line = list(map(int, input().split()))
        training_data.append((line[:m], line[m]))
        
    initial_indices = list(range(n))
    used_features = set()
    root = build_tree(initial_indices, used_features, training_data, m)
    
    q = int(input())
    for _ in range(q):
        query_features = list(map(int, input().split()))
        print(predict(root, query_features))

solve()

Algorithm and Complexity

  • Algorithm: ID3 decision tree
  • Time complexity: $O(n \cdot m^2 + q \cdot m)$, where $n$ is the number of training samples, $m$ the number of features, and $q$ the number of queries. At each level of the tree we scan the samples that reach that level and compute the information gain of every unused feature; the depth of the tree is at most $m$, so building the tree costs roughly $O(n \cdot m^2)$. Predicting one sample walks from the root to a leaf, which is $O(m)$.
  • Space complexity: $O(n \cdot m)$, dominated by storing the training data; the decision tree itself uses comparatively little space.