题目链接

基于决策树的QAM调制符号检测

题目描述

用一小段带噪声的复数信号样本来训练一棵 CART 决策树,对16QAM符号进行分类判决。每个样本用两维实数特征表示:实部 与虚部 ;标签是整型类标(如0~15,对应16QAM的16个星座点)。

  • 划分标准:使用基尼系数(Gini)作为节点不纯度度量,选择加权 Gini 最小且左右子集均非空的划分。
  • 切分方式:只允许在特征 上,用固定阈值集合 中的某个阈值 进行二分;样本按“特征值 ” 分入左子树,否则进入右子树。
  • 叶子输出:叶子节点输出该节点内样本的多数类;若并列,取数值较小的标签,保证确定性。
  • 树深度:最大深度为 5(根深度计 1)。
  • 训练集整体Gini:按训练集中各标签频率一次性计算并输出。

请在读入训练样本后,先输出训练集整体 Gini(四舍五入保留 4 位小数),再用训练好的树对给定测试点 进行预测并输出其标签。

输入描述

  • 第 1 行:整数 ,表示训练样本数。
  • 第 2~M+1 行:每行三个数 为实数,y 为整型类标)。
  • 第 M+2 行:两个实数 ,表示测试样本的特征。

输出描述

  • 第 1 行:训练集整体 Gini,四舍五入保留 4 位小数。
  • 第 2 行:对测试样本的预测标签(整数)。

解题思路

本题要求我们手动实现一个简化的CART决策树的构建和预测过程。整个过程可以分为三个主要部分:计算初始Gini系数、递归构建决策树、使用构建好的树进行预测。

1. 整体Gini系数计算

这是题目的第一个输出。我们需要遍历整个训练数据集,统计每个标签(类别)出现的次数。Gini不纯度的计算公式为: 其中, 是数据集, 是类别总数, 是第 类样本占总样本数的比例。我们计算出这个值并按要求四舍五入保留4位小数即可。

2. 决策树的递归构建

决策树的生成是一个递归过程。我们可以定义一个函数,如 build_tree(samples, depth),它接收一个样本子集和当前深度作为输入,返回一个树节点。

  • 树的节点结构

    • 叶节点 (Leaf Node):表示一个决策的终点。需要存储该节点所代表的类别(即该节点中样本的多数类)。
    • 内部节点 (Internal Node):表示一个决策(分裂)点。需要存储分裂所依据的特征()、分裂的阈值,以及指向左、右两个子节点的引用。
  • 递归构建流程 build_tree(samples, depth):

    1. 终止条件(变为叶节点):

      • 当前节点包含的样本都属于同一类别(Gini系数为0)。
      • 树的深度达到了最大限制 5。
      • 找不到任何一个有效的、能降低Gini不纯度的分裂方式。
      • 如果满足以上任一条件,则当前节点成为叶节点。计算该节点samples中的多数类作为其预测值(如果票数并列,取数值较小的标签),然后返回该叶节点。
    2. 寻找最佳分裂点:

      • 遍历所有可能的分裂维度(特征)。
      • 对于每个维度,遍历所有可能的分裂阈值
      • 对于每一个(特征, 阈值)组合,将当前节点的samples分成两个子集:左子集(特征值 )和右子集(特征值 )。
      • 有效性检查:如果分裂导致任一子集为空,则该分裂无效,跳过。
      • 计算加权Gini:对于有效的活性,计算分裂后的加权Gini不纯度:
      • 在所有有效的分裂中,找到那个使 最小的分裂点(特征+阈值)。
    3. 递归生成子树:

      • 如果找到的最佳分裂点的 严格小于当前节点的Gini不纯度,说明这是一个有益的分裂。
      • 创建一个内部节点,记录下最佳分裂的特征和阈值。
      • 使用分裂出的左、右两个样本子集,分别递归调用 build_tree(left_samples, depth + 1)build_tree(right_samples, depth + 1),将其返回的节点作为当前节点的左、右子节点。
      • 返回当前创建的内部节点。

3. 对测试样本进行预测

这是一个简单的树遍历过程。从根节点开始:

  1. 判断当前节点是内部节点还是叶节点。
  2. 如果是叶节点,则其存储的预测值就是最终结果。
  3. 如果是内部节点,根据其存储的分裂特征和阈值,判断测试样本应该进入左子树还是右子树。
  4. 移动到对应的子节点,重复步骤1。

代码

#include <iostream>
#include <vector>
#include <map>
#include <cmath>
#include <iomanip>
#include <algorithm>
#include <limits>

using namespace std;

// 样本结构体
struct Sample {
    double x1, x2;
    int y;
};

// 决策树节点结构体
struct Node {
    bool is_leaf;
    int prediction; // 叶节点存储预测类别
    int feature_idx; // 内部节点存储分裂特征索引 (0 for x1, 1 for x2)
    double threshold; // 内部节点存储分裂阈值
    Node *left = nullptr, *right = nullptr;

    // 叶节点构造函数
    Node(int pred) : is_leaf(true), prediction(pred) {}
    // 内部节点构造函数
    Node(int feat_idx, double thresh) : is_leaf(false), feature_idx(feat_idx), threshold(thresh) {}
};

// 计算给定样本集的 Gini 不纯度
double calculate_gini(const vector<Sample>& samples) {
    if (samples.empty()) {
        return 0.0;
    }
    map<int, int> counts;
    for (const auto& s : samples) {
        counts[s.y]++;
    }
    double gini = 1.0;
    for (const auto& pair : counts) {
        double p = (double)pair.second / samples.size();
        gini -= p * p;
    }
    return gini;
}

// 获取样本集中的多数类,如果票数相同则取类别值较小的
int get_majority_class(const vector<Sample>& samples) {
    map<int, int> counts;
    for (const auto& s : samples) {
        counts[s.y]++;
    }
    int max_count = 0;
    int majority_class = -1;
    // C++ map 默认按 key 排序,所以遍历时自然满足“取数值较小的标签”
    for (const auto& pair : counts) {
        if (pair.second > max_count) {
            max_count = pair.second;
            majority_class = pair.first;
        } else if (pair.second == max_count) {
            majority_class = min(majority_class, pair.first);
        }
    }
    return majority_class;
}

// 递归构建决策树
Node* build_tree(const vector<Sample>& samples, int depth) {
    // 终止条件:样本为空、达到最大深度、或 Gini 为 0 (已纯净)
    if (samples.empty() || depth > 5 || calculate_gini(samples) == 0) {
        return new Node(get_majority_class(samples));
    }

    double current_gini = calculate_gini(samples);
    double best_gini = current_gini;
    int best_feature = -1;
    double best_threshold = 0;
    vector<Sample> best_left, best_right;

    int features[] = {0, 1}; // 特征索引 x1, x2
    double thresholds[] = {-3, -2, -1, 0, 1, 2, 3}; // 预设阈值

    // 遍历所有特征和阈值,寻找最佳分裂点
    for (int feature_idx : features) {
        for (double t : thresholds) {
            vector<Sample> left_samples, right_samples;
            for (const auto& s : samples) {
                double val = (feature_idx == 0) ? s.x1 : s.x2;
                if (val <= t) {
                    left_samples.push_back(s);
                } else {
                    right_samples.push_back(s);
                }
            }
            // 分裂必须保证左右子集都非空
            if (left_samples.empty() || right_samples.empty()) {
                continue;
            }
            
            // 计算加权 Gini 系数
            double weighted_gini = (double)left_samples.size() / samples.size() * calculate_gini(left_samples) +
                                   (double)right_samples.size() / samples.size() * calculate_gini(right_samples);
            
            // 寻找能使 Gini 最小化的分裂
            if (weighted_gini < best_gini) {
                best_gini = weighted_gini;
                best_feature = feature_idx;
                best_threshold = t;
                best_left = left_samples;
                best_right = right_samples;
            }
        }
    }

    // 如果找不到更好的分裂,则当前节点变为叶节点
    if (best_feature == -1) {
        return new Node(get_majority_class(samples));
    }

    // 创建内部节点,并递归构建左右子树
    Node* node = new Node(best_feature, best_threshold);
    node->left = build_tree(best_left, depth + 1);
    node->right = build_tree(best_right, depth + 1);
    return node;
}

// 使用训练好的决策树进行预测
int predict(Node* node, double tx1, double tx2) {
    while (!node->is_leaf) {
        double val = (node->feature_idx == 0) ? tx1 : tx2;
        if (val <= node->threshold) {
            node = node->left;
        } else {
            node = node->right;
        }
    }
    return node->prediction;
}

int main() {
    int m;
    cin >> m;
    vector<Sample> all_samples(m);
    for (int i = 0; i < m; ++i) {
        cin >> all_samples[i].x1 >> all_samples[i].x2 >> all_samples[i].y;
    }
    double tx1, tx2;
    cin >> tx1 >> tx2;

    // 输出整体 Gini
    cout << fixed << setprecision(4) << calculate_gini(all_samples) << endl;

    // 构建决策树
    Node* root = build_tree(all_samples, 1);
    // 输出预测结果
    cout << predict(root, tx1, tx2) << endl;

    return 0;
}
import java.util.*;
import java.io.*;

public class Main {
    // 样本类
    static class Sample {
        double x1, x2;
        int y;
        Sample(double x1, double x2, int y) {
            this.x1 = x1; this.x2 = x2; this.y = y;
        }
    }

    // 决策树节点类
    static class Node {
        boolean isLeaf;
        int prediction; // 叶节点:预测类别
        int featureIdx; // 内部节点:分裂特征
        double threshold; // 内部节点:分裂阈值
        Node left, right;

        // 叶节点构造器
        Node(int prediction) {
            this.isLeaf = true; this.prediction = prediction;
        }
        // 内部节点构造器
        Node(int featureIdx, double threshold) {
            this.isLeaf = false; this.featureIdx = featureIdx; this.threshold = threshold;
        }
    }

    // 计算 Gini 不纯度
    static double calculateGini(List<Sample> samples) {
        if (samples.isEmpty()) return 0.0;
        Map<Integer, Integer> counts = new HashMap<>();
        for (Sample s : samples) {
            counts.put(s.y, counts.getOrDefault(s.y, 0) + 1);
        }
        double gini = 1.0;
        for (int count : counts.values()) {
            double p = (double) count / samples.size();
            gini -= p * p;
        }
        return gini;
    }

    // 获取多数类,票数相同时取值小的
    static int getMajorityClass(List<Sample> samples) {
        if (samples.isEmpty()) return -1;
        // TreeMap 自动按 key (标签值) 排序,便于处理平票情况
        Map<Integer, Integer> counts = new TreeMap<>();
        for (Sample s : samples) {
            counts.put(s.y, counts.getOrDefault(s.y, 0) + 1);
        }
        int maxCount = 0;
        int majorityClass = -1;
        // 由于 TreeMap 的有序性,我们能保证在票数相同时,第一次出现的即是值最小的
        for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > maxCount) {
                maxCount = entry.getValue();
                majorityClass = entry.getKey();
            }
        }
        return majorityClass;
    }

    // 递归构建决策树
    static Node buildTree(List<Sample> samples, int depth) {
        // 终止条件
        if (samples.isEmpty() || depth > 5 || calculateGini(samples) == 0) {
            return new Node(getMajorityClass(samples));
        }

        double currentGini = calculateGini(samples);
        double bestGini = currentGini;
        int bestFeature = -1;
        double bestThreshold = 0;
        List<Sample> bestLeft = null, bestRight = null;

        int[] features = {0, 1}; // 特征索引 x1, x2
        double[] thresholds = {-3, -2, -1, 0, 1, 2, 3}; // 预设阈值

        // 寻找最佳分裂点
        for (int featureIdx : features) {
            for (double t : thresholds) {
                List<Sample> leftSamples = new ArrayList<>(), rightSamples = new ArrayList<>();
                for (Sample s : samples) {
                    double val = (featureIdx == 0) ? s.x1 : s.x2;
                    if (val <= t) leftSamples.add(s);
                    else rightSamples.add(s);
                }
                // 保证左右子集非空
                if (leftSamples.isEmpty() || rightSamples.isEmpty()) continue;
                
                // 计算加权 Gini
                double weightedGini = (double)leftSamples.size() / samples.size() * calculateGini(leftSamples) +
                                      (double)rightSamples.size() / samples.size() * calculateGini(rightSamples);
                
                // 更新最佳分裂
                if (weightedGini < bestGini) {
                    bestGini = weightedGini;
                    bestFeature = featureIdx;
                    bestThreshold = t;
                    bestLeft = leftSamples;
                    bestRight = rightSamples;
                }
            }
        }

        // 找不到更好的分裂,变为叶节点
        if (bestFeature == -1) {
            return new Node(getMajorityClass(samples));
        }

        // 递归构建子树
        Node node = new Node(bestFeature, bestThreshold);
        node.left = buildTree(bestLeft, depth + 1);
        node.right = buildTree(bestRight, depth + 1);
        return node;
    }
    
    // 预测
    static int predict(Node node, double tx1, double tx2) {
        while (!node.isLeaf) {
            double val = (node.featureIdx == 0) ? tx1 : tx2;
            node = (val <= node.threshold) ? node.left : node.right;
        }
        return node.prediction;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int m = sc.nextInt();
        List<Sample> allSamples = new ArrayList<>();
        for (int i = 0; i < m; i++) {
            allSamples.add(new Sample(sc.nextDouble(), sc.nextDouble(), sc.nextInt()));
        }
        double tx1 = sc.nextDouble();
        double tx2 = sc.nextDouble();

        // 输出整体 Gini
        System.out.printf("%.4f\n", calculateGini(allSamples));

        // 构建并预测
        Node root = buildTree(allSamples, 1);
        System.out.println(predict(root, tx1, tx2));
    }
}
import sys
from collections import Counter

# 决策树节点类
class Node:
    def __init__(self, prediction=None, feature_idx=None, threshold=None, left=None, right=None):
        self.prediction = prediction      # 叶节点:预测类别
        self.feature_idx = feature_idx    # 内部节点:分裂特征
        self.threshold = threshold      # 内部节点:分裂阈值
        self.left = left                # 左子树
        self.right = right              # 右子树

    @property
    def is_leaf(self):
        # 通过是否有预测值来判断是否为叶节点
        return self.prediction is not None

# 计算 Gini 不纯度
def calculate_gini(samples):
    if not samples:
        return 0.0
    labels = [s[2] for s in samples]
    counts = Counter(labels)
    total = len(samples)
    gini = 1.0
    for label in counts:
        p = counts[label] / total
        gini -= p * p
    return gini

# 获取多数类,票数相同时取值小的
def get_majority_class(samples):
    if not samples:
        return -1
    labels = [s[2] for s in samples]
    counts = Counter(labels)
    # 寻找票数最多的类别
    max_count = 0
    majority_class = -1
    # 对类别排序,以确保在票数相同时,优先选择数值较小的类别
    for label in sorted(counts.keys()):
        if counts[label] > max_count:
            max_count = counts[label]
            majority_class = label
    return majority_class

# 递归构建决策树
def build_tree(samples, depth):
    # 终止条件:样本为空、达到最大深度、或已纯净
    if not samples or depth > 5 or calculate_gini(samples) == 0:
        return Node(prediction=get_majority_class(samples))

    current_gini = calculate_gini(samples)
    best_gini = current_gini
    best_split = None
    best_subsets = None

    features = [0, 1] # 特征 x1, x2
    thresholds = [-3, -2, -1, 0, 1, 2, 3] # 阈值集合

    # 遍历寻找最佳分裂点
    for feature_idx in features:
        for t in thresholds:
            left_samples = [s for s in samples if s[feature_idx] <= t]
            right_samples = [s for s in samples if s[feature_idx] > t]

            # 左右子集必须非空
            if not left_samples or not right_samples:
                continue
            
            # 计算加权 Gini
            weighted_gini = (len(left_samples) / len(samples)) * calculate_gini(left_samples) + \
                            (len(right_samples) / len(samples)) * calculate_gini(right_samples)

            # 更新最佳分裂
            if weighted_gini < best_gini:
                best_gini = weighted_gini
                best_split = (feature_idx, t)
                best_subsets = (left_samples, right_samples)

    # 如果 Gini 没有降低,则变为叶节点
    if best_split is None:
        return Node(prediction=get_majority_class(samples))

    # 递归构建子树
    feature_idx, threshold = best_split
    left_samples, right_samples = best_subsets
    left_child = build_tree(left_samples, depth + 1)
    right_child = build_tree(right_samples, depth + 1)
    return Node(feature_idx=feature_idx, threshold=threshold, left=left_child, right=right_child)

# 预测
def predict(node, test_sample):
    # 沿着树向下遍历,直到叶节点
    while not node.is_leaf:
        if test_sample[node.feature_idx] <= node.threshold:
            node = node.left
        else:
            node = node.right
    return node.prediction

def solve():
    m = int(input())
    all_samples = []
    for i in range(m):
        x1, x2, y = map(float, input().split())
        all_samples.append((x1, x2, y))
    
    tx1, tx2 = map(float, input().split())

    # 输出整体 Gini
    print(f"{calculate_gini(all_samples):.4f}")

    # 构建并预测
    root = build_tree(all_samples, 1)
    prediction = predict(root, (tx1, tx2))
    print(int(prediction))

solve()

算法及复杂度

  • 算法: 决策树(模拟)
  • 时间复杂度: ,其中 是训练样本数, 是最大深度(5), 是特征数(2), 是阈值数(7)。由于 均为小常数,复杂度可视为
  • 空间复杂度: ,主要用于存储训练样本数据和递归调用栈。