基于决策树的QAM调制符号检测

题目分析

给定 个训练样本,每个样本包含两个特征 (复信号的实部和虚部)和一个标签 (16QAM 星座点编号)。要求:

  1. 计算训练集整体的基尼系数(Gini),保留 4 位小数。
  2. 用 CART 决策树对测试样本进行分类预测。

决策树的构建规则:

  • 分裂准则:基尼系数 ,选择使加权基尼最小的划分。
  • 候选阈值:仅
  • 分裂方向:特征值 阈值进入左子树, 阈值进入右子树。
  • 有效分裂:左右子集必须均非空,且加权基尼严格小于当前节点基尼。
  • 最大深度:5(根节点深度计为 1)。
  • 叶节点输出:多数类;若并列取数值较小的标签。

思路

模拟 CART 决策树构建

这道题没有复杂的算法技巧,关键是严格按照题目规则实现 CART 决策树。

基尼系数计算:统计各标签出现频率 ,代入 即可。

递归建树:对当前节点,遍历 2 个特征 7 个候选阈值共 14 种划分方案。对每种方案计算加权基尼 ,选最小的。如果没有任何方案能严格降低基尼,或已达最大深度,则标记为叶节点。

预测:从根节点出发,按特征值与阈值的比较走左或右子树,直到叶节点,输出该叶节点的多数类标签。

代码

import sys
from collections import Counter

def gini(labels):
    n = len(labels)
    if n == 0:
        return 0.0
    cnt = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in cnt.values())

def majority_label(labels):
    cnt = Counter(labels)
    max_count = max(cnt.values())
    candidates = [k for k, v in cnt.items() if v == max_count]
    return min(candidates)

def build_tree(data, depth, max_depth):
    labels = [d[2] for d in data]
    if depth >= max_depth or len(data) == 0:
        return ('leaf', majority_label(labels))
    current_gini = gini(labels)
    if current_gini == 0.0:
        return ('leaf', majority_label(labels))

    best_wg = None
    best_feature = None
    best_threshold = None
    n = len(data)

    for feature in [0, 1]:
        for t in [-3, -2, -1, 0, 1, 2, 3]:
            left = [d for d in data if d[feature] < t]
            right = [d for d in data if d[feature] >= t]
            if not left or not right:
                continue
            wg = (len(left) / n) * gini([d[2] for d in left]) + \
                 (len(right) / n) * gini([d[2] for d in right])
            if wg >= current_gini:
                continue
            if best_wg is None or wg < best_wg:
                best_wg = wg
                best_feature = feature
                best_threshold = t

    if best_wg is None:
        return ('leaf', majority_label(labels))

    left_data = [d for d in data if d[best_feature] < best_threshold]
    right_data = [d for d in data if d[best_feature] >= best_threshold]
    return ('node', best_feature, best_threshold,
            build_tree(left_data, depth + 1, max_depth),
            build_tree(right_data, depth + 1, max_depth))

def predict(tree, x1, x2):
    if tree[0] == 'leaf':
        return tree[1]
    _, feature, threshold, left, right = tree
    if (x1 if feature == 0 else x2) < threshold:
        return predict(left, x1, x2)
    return predict(right, x1, x2)

def main():
    data = sys.stdin.read().split()
    idx = 0
    M = int(data[idx]); idx += 1
    samples = []
    for _ in range(M):
        x1 = float(data[idx]); x2 = float(data[idx+1]); y = int(data[idx+2])
        idx += 3
        samples.append((x1, x2, y))
    tx1 = float(data[idx]); tx2 = float(data[idx+1])

    print(f"{gini([d[2] for d in samples]):.4f}")
    tree = build_tree(samples, 1, 5)
    print(predict(tree, tx1, tx2))

if __name__ == '__main__':
    main()

复杂度分析

  • 时间复杂度,其中 为树的深度(最大 5),每层对 14 种划分方案各遍历所有样本。
  • 空间复杂度,存储训练样本和递归栈。