题目链接
题目描述
用一小段带噪声的复数信号样本来训练一棵 CART 决策树,对16QAM符号进行分类判决。每个样本用两维实数特征表示:实部 与虚部
;标签是整型类标(如0~15,对应16QAM的16个星座点)。
- 划分标准:使用基尼系数(Gini)作为节点不纯度度量,选择加权 Gini 最小且左右子集均非空的划分。
- 切分方式:只允许在特征
或
上,用固定阈值集合
中的某个阈值
进行二分;样本按“特征值
” 分入左子树,否则进入右子树。
- 叶子输出:叶子节点输出该节点内样本的多数类;若并列,取数值较小的标签,保证确定性。
- 树深度:最大深度为 5(根深度计 1)。
- 训练集整体Gini:按训练集中各标签频率一次性计算并输出。
请在读入训练样本后,先输出训练集整体 Gini(四舍五入保留 4 位小数),再用训练好的树对给定测试点 进行预测并输出其标签。
输入描述
- 第 1 行:整数
,表示训练样本数。
- 第 2~M+1 行:每行三个数
(
为实数,y 为整型类标)。
- 第 M+2 行:两个实数
,表示测试样本的特征。
输出描述
- 第 1 行:训练集整体 Gini,四舍五入保留 4 位小数。
- 第 2 行:对测试样本的预测标签(整数)。
解题思路
本题要求我们手动实现一个简化的CART决策树的构建和预测过程。整个过程可以分为三个主要部分:计算初始Gini系数、递归构建决策树、使用构建好的树进行预测。
1. 整体Gini系数计算
这是题目的第一个输出。我们需要遍历整个训练数据集,统计每个标签(类别)出现的次数。Gini不纯度的计算公式为:
其中,
是数据集,
是类别总数,
是第
类样本占总样本数的比例。我们计算出这个值并按要求四舍五入保留4位小数即可。
2. 决策树的递归构建
决策树的生成是一个递归过程。我们可以定义一个函数,如 build_tree(samples, depth),它接收一个样本子集和当前深度作为输入,返回一个树节点。
-
树的节点结构:
- 叶节点 (Leaf Node):表示一个决策的终点。需要存储该节点所代表的类别(即该节点中样本的多数类)。
- 内部节点 (Internal Node):表示一个决策(分裂)点。需要存储分裂所依据的特征(
或
)、分裂的阈值,以及指向左、右两个子节点的引用。
-
递归构建流程
build_tree(samples, depth):-
终止条件(变为叶节点):
- 当前节点包含的样本都属于同一类别(Gini系数为0)。
- 树的深度达到了最大限制 5。
- 找不到任何一个有效的、能降低Gini不纯度的分裂方式。
- 如果满足以上任一条件,则当前节点成为叶节点。计算该节点
samples中的多数类作为其预测值(如果票数并列,取数值较小的标签),然后返回该叶节点。
-
寻找最佳分裂点:
- 遍历所有可能的分裂维度(特征
和
)。
- 对于每个维度,遍历所有可能的分裂阈值
。
- 对于每一个(特征, 阈值)组合,将当前节点的
samples分成两个子集:左子集(特征值)和右子集(特征值
)。
- 有效性检查:如果分裂导致任一子集为空,则该分裂无效,跳过。
- 计算加权Gini:对于有效的活性,计算分裂后的加权Gini不纯度:
- 在所有有效的分裂中,找到那个使
最小的分裂点(特征+阈值)。
- 遍历所有可能的分裂维度(特征
-
递归生成子树:
- 如果找到的最佳分裂点的
严格小于当前节点的Gini不纯度,说明这是一个有益的分裂。
- 创建一个内部节点,记录下最佳分裂的特征和阈值。
- 使用分裂出的左、右两个样本子集,分别递归调用
build_tree(left_samples, depth + 1)和build_tree(right_samples, depth + 1),将其返回的节点作为当前节点的左、右子节点。 - 返回当前创建的内部节点。
- 如果找到的最佳分裂点的
-
3. 对测试样本进行预测
这是一个简单的树遍历过程。从根节点开始:
- 判断当前节点是内部节点还是叶节点。
- 如果是叶节点,则其存储的预测值就是最终结果。
- 如果是内部节点,根据其存储的分裂特征和阈值,判断测试样本应该进入左子树还是右子树。
- 移动到对应的子节点,重复步骤1。
代码
#include <iostream>
#include <vector>
#include <map>
#include <cmath>
#include <iomanip>
#include <algorithm>
#include <limits>
using namespace std;
// 样本结构体
struct Sample {
double x1, x2;
int y;
};
// 决策树节点结构体
struct Node {
bool is_leaf;
int prediction; // 叶节点存储预测类别
int feature_idx; // 内部节点存储分裂特征索引 (0 for x1, 1 for x2)
double threshold; // 内部节点存储分裂阈值
Node *left = nullptr, *right = nullptr;
// 叶节点构造函数
Node(int pred) : is_leaf(true), prediction(pred) {}
// 内部节点构造函数
Node(int feat_idx, double thresh) : is_leaf(false), feature_idx(feat_idx), threshold(thresh) {}
};
// 计算给定样本集的 Gini 不纯度
double calculate_gini(const vector<Sample>& samples) {
if (samples.empty()) {
return 0.0;
}
map<int, int> counts;
for (const auto& s : samples) {
counts[s.y]++;
}
double gini = 1.0;
for (const auto& pair : counts) {
double p = (double)pair.second / samples.size();
gini -= p * p;
}
return gini;
}
// 获取样本集中的多数类,如果票数相同则取类别值较小的
int get_majority_class(const vector<Sample>& samples) {
map<int, int> counts;
for (const auto& s : samples) {
counts[s.y]++;
}
int max_count = 0;
int majority_class = -1;
// C++ map 默认按 key 排序,所以遍历时自然满足“取数值较小的标签”
for (const auto& pair : counts) {
if (pair.second > max_count) {
max_count = pair.second;
majority_class = pair.first;
} else if (pair.second == max_count) {
majority_class = min(majority_class, pair.first);
}
}
return majority_class;
}
// 递归构建决策树
Node* build_tree(const vector<Sample>& samples, int depth) {
// 终止条件:样本为空、达到最大深度、或 Gini 为 0 (已纯净)
if (samples.empty() || depth > 5 || calculate_gini(samples) == 0) {
return new Node(get_majority_class(samples));
}
double current_gini = calculate_gini(samples);
double best_gini = current_gini;
int best_feature = -1;
double best_threshold = 0;
vector<Sample> best_left, best_right;
int features[] = {0, 1}; // 特征索引 x1, x2
double thresholds[] = {-3, -2, -1, 0, 1, 2, 3}; // 预设阈值
// 遍历所有特征和阈值,寻找最佳分裂点
for (int feature_idx : features) {
for (double t : thresholds) {
vector<Sample> left_samples, right_samples;
for (const auto& s : samples) {
double val = (feature_idx == 0) ? s.x1 : s.x2;
if (val <= t) {
left_samples.push_back(s);
} else {
right_samples.push_back(s);
}
}
// 分裂必须保证左右子集都非空
if (left_samples.empty() || right_samples.empty()) {
continue;
}
// 计算加权 Gini 系数
double weighted_gini = (double)left_samples.size() / samples.size() * calculate_gini(left_samples) +
(double)right_samples.size() / samples.size() * calculate_gini(right_samples);
// 寻找能使 Gini 最小化的分裂
if (weighted_gini < best_gini) {
best_gini = weighted_gini;
best_feature = feature_idx;
best_threshold = t;
best_left = left_samples;
best_right = right_samples;
}
}
}
// 如果找不到更好的分裂,则当前节点变为叶节点
if (best_feature == -1) {
return new Node(get_majority_class(samples));
}
// 创建内部节点,并递归构建左右子树
Node* node = new Node(best_feature, best_threshold);
node->left = build_tree(best_left, depth + 1);
node->right = build_tree(best_right, depth + 1);
return node;
}
// 使用训练好的决策树进行预测
int predict(Node* node, double tx1, double tx2) {
while (!node->is_leaf) {
double val = (node->feature_idx == 0) ? tx1 : tx2;
if (val <= node->threshold) {
node = node->left;
} else {
node = node->right;
}
}
return node->prediction;
}
int main() {
int m;
cin >> m;
vector<Sample> all_samples(m);
for (int i = 0; i < m; ++i) {
cin >> all_samples[i].x1 >> all_samples[i].x2 >> all_samples[i].y;
}
double tx1, tx2;
cin >> tx1 >> tx2;
// 输出整体 Gini
cout << fixed << setprecision(4) << calculate_gini(all_samples) << endl;
// 构建决策树
Node* root = build_tree(all_samples, 1);
// 输出预测结果
cout << predict(root, tx1, tx2) << endl;
return 0;
}
import java.util.*;
import java.io.*;
public class Main {
// 样本类
static class Sample {
double x1, x2;
int y;
Sample(double x1, double x2, int y) {
this.x1 = x1; this.x2 = x2; this.y = y;
}
}
// 决策树节点类
static class Node {
boolean isLeaf;
int prediction; // 叶节点:预测类别
int featureIdx; // 内部节点:分裂特征
double threshold; // 内部节点:分裂阈值
Node left, right;
// 叶节点构造器
Node(int prediction) {
this.isLeaf = true; this.prediction = prediction;
}
// 内部节点构造器
Node(int featureIdx, double threshold) {
this.isLeaf = false; this.featureIdx = featureIdx; this.threshold = threshold;
}
}
// 计算 Gini 不纯度
static double calculateGini(List<Sample> samples) {
if (samples.isEmpty()) return 0.0;
Map<Integer, Integer> counts = new HashMap<>();
for (Sample s : samples) {
counts.put(s.y, counts.getOrDefault(s.y, 0) + 1);
}
double gini = 1.0;
for (int count : counts.values()) {
double p = (double) count / samples.size();
gini -= p * p;
}
return gini;
}
// 获取多数类,票数相同时取值小的
static int getMajorityClass(List<Sample> samples) {
if (samples.isEmpty()) return -1;
// TreeMap 自动按 key (标签值) 排序,便于处理平票情况
Map<Integer, Integer> counts = new TreeMap<>();
for (Sample s : samples) {
counts.put(s.y, counts.getOrDefault(s.y, 0) + 1);
}
int maxCount = 0;
int majorityClass = -1;
// 由于 TreeMap 的有序性,我们能保证在票数相同时,第一次出现的即是值最小的
for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
if (entry.getValue() > maxCount) {
maxCount = entry.getValue();
majorityClass = entry.getKey();
}
}
return majorityClass;
}
// 递归构建决策树
static Node buildTree(List<Sample> samples, int depth) {
// 终止条件
if (samples.isEmpty() || depth > 5 || calculateGini(samples) == 0) {
return new Node(getMajorityClass(samples));
}
double currentGini = calculateGini(samples);
double bestGini = currentGini;
int bestFeature = -1;
double bestThreshold = 0;
List<Sample> bestLeft = null, bestRight = null;
int[] features = {0, 1}; // 特征索引 x1, x2
double[] thresholds = {-3, -2, -1, 0, 1, 2, 3}; // 预设阈值
// 寻找最佳分裂点
for (int featureIdx : features) {
for (double t : thresholds) {
List<Sample> leftSamples = new ArrayList<>(), rightSamples = new ArrayList<>();
for (Sample s : samples) {
double val = (featureIdx == 0) ? s.x1 : s.x2;
if (val <= t) leftSamples.add(s);
else rightSamples.add(s);
}
// 保证左右子集非空
if (leftSamples.isEmpty() || rightSamples.isEmpty()) continue;
// 计算加权 Gini
double weightedGini = (double)leftSamples.size() / samples.size() * calculateGini(leftSamples) +
(double)rightSamples.size() / samples.size() * calculateGini(rightSamples);
// 更新最佳分裂
if (weightedGini < bestGini) {
bestGini = weightedGini;
bestFeature = featureIdx;
bestThreshold = t;
bestLeft = leftSamples;
bestRight = rightSamples;
}
}
}
// 找不到更好的分裂,变为叶节点
if (bestFeature == -1) {
return new Node(getMajorityClass(samples));
}
// 递归构建子树
Node node = new Node(bestFeature, bestThreshold);
node.left = buildTree(bestLeft, depth + 1);
node.right = buildTree(bestRight, depth + 1);
return node;
}
// 预测
static int predict(Node node, double tx1, double tx2) {
while (!node.isLeaf) {
double val = (node.featureIdx == 0) ? tx1 : tx2;
node = (val <= node.threshold) ? node.left : node.right;
}
return node.prediction;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int m = sc.nextInt();
List<Sample> allSamples = new ArrayList<>();
for (int i = 0; i < m; i++) {
allSamples.add(new Sample(sc.nextDouble(), sc.nextDouble(), sc.nextInt()));
}
double tx1 = sc.nextDouble();
double tx2 = sc.nextDouble();
// 输出整体 Gini
System.out.printf("%.4f\n", calculateGini(allSamples));
// 构建并预测
Node root = buildTree(allSamples, 1);
System.out.println(predict(root, tx1, tx2));
}
}
import sys
from collections import Counter
# 决策树节点类
class Node:
def __init__(self, prediction=None, feature_idx=None, threshold=None, left=None, right=None):
self.prediction = prediction # 叶节点:预测类别
self.feature_idx = feature_idx # 内部节点:分裂特征
self.threshold = threshold # 内部节点:分裂阈值
self.left = left # 左子树
self.right = right # 右子树
@property
def is_leaf(self):
# 通过是否有预测值来判断是否为叶节点
return self.prediction is not None
# 计算 Gini 不纯度
def calculate_gini(samples):
if not samples:
return 0.0
labels = [s[2] for s in samples]
counts = Counter(labels)
total = len(samples)
gini = 1.0
for label in counts:
p = counts[label] / total
gini -= p * p
return gini
# 获取多数类,票数相同时取值小的
def get_majority_class(samples):
if not samples:
return -1
labels = [s[2] for s in samples]
counts = Counter(labels)
# 寻找票数最多的类别
max_count = 0
majority_class = -1
# 对类别排序,以确保在票数相同时,优先选择数值较小的类别
for label in sorted(counts.keys()):
if counts[label] > max_count:
max_count = counts[label]
majority_class = label
return majority_class
# 递归构建决策树
def build_tree(samples, depth):
# 终止条件:样本为空、达到最大深度、或已纯净
if not samples or depth > 5 or calculate_gini(samples) == 0:
return Node(prediction=get_majority_class(samples))
current_gini = calculate_gini(samples)
best_gini = current_gini
best_split = None
best_subsets = None
features = [0, 1] # 特征 x1, x2
thresholds = [-3, -2, -1, 0, 1, 2, 3] # 阈值集合
# 遍历寻找最佳分裂点
for feature_idx in features:
for t in thresholds:
left_samples = [s for s in samples if s[feature_idx] <= t]
right_samples = [s for s in samples if s[feature_idx] > t]
# 左右子集必须非空
if not left_samples or not right_samples:
continue
# 计算加权 Gini
weighted_gini = (len(left_samples) / len(samples)) * calculate_gini(left_samples) + \
(len(right_samples) / len(samples)) * calculate_gini(right_samples)
# 更新最佳分裂
if weighted_gini < best_gini:
best_gini = weighted_gini
best_split = (feature_idx, t)
best_subsets = (left_samples, right_samples)
# 如果 Gini 没有降低,则变为叶节点
if best_split is None:
return Node(prediction=get_majority_class(samples))
# 递归构建子树
feature_idx, threshold = best_split
left_samples, right_samples = best_subsets
left_child = build_tree(left_samples, depth + 1)
right_child = build_tree(right_samples, depth + 1)
return Node(feature_idx=feature_idx, threshold=threshold, left=left_child, right=right_child)
# 预测
def predict(node, test_sample):
# 沿着树向下遍历,直到叶节点
while not node.is_leaf:
if test_sample[node.feature_idx] <= node.threshold:
node = node.left
else:
node = node.right
return node.prediction
def solve():
m = int(input())
all_samples = []
for i in range(m):
x1, x2, y = map(float, input().split())
all_samples.append((x1, x2, y))
tx1, tx2 = map(float, input().split())
# 输出整体 Gini
print(f"{calculate_gini(all_samples):.4f}")
# 构建并预测
root = build_tree(all_samples, 1)
prediction = predict(root, (tx1, tx2))
print(int(prediction))
solve()
算法及复杂度
- 算法: 决策树(模拟)
- 时间复杂度:
,其中
是训练样本数,
是最大深度(5),
是特征数(2),
是阈值数(7)。由于
均为小常数,复杂度可视为
。
- 空间复杂度:
,主要用于存储训练样本数据和递归调用栈。

京公网安备 11010502036488号