基于决策树的QAM调制符号检测
题目分析
给定 个训练样本,每个样本包含两个特征
(复信号的实部和虚部)和一个标签
(16QAM 星座点编号)。要求:
- 计算训练集整体的基尼系数(Gini),保留 4 位小数。
- 用 CART 决策树对测试样本进行分类预测。
决策树的构建规则:
- 分裂准则:基尼系数
,选择使加权基尼最小的划分。
- 候选阈值:仅
。
- 分裂方向:特征值
阈值进入左子树,
阈值进入右子树。
- 有效分裂:左右子集必须均非空,且加权基尼严格小于当前节点基尼。
- 最大深度:5(根节点深度计为 1)。
- 叶节点输出:多数类;若并列取数值较小的标签。
思路
模拟 CART 决策树构建
这道题没有复杂的算法技巧,关键是严格按照题目规则实现 CART 决策树。
基尼系数计算:统计各标签出现频率 ,代入
即可。
递归建树:对当前节点,遍历 2 个特征 7 个候选阈值共 14 种划分方案。对每种方案计算加权基尼
,选最小的。如果没有任何方案能严格降低基尼,或已达最大深度,则标记为叶节点。
预测:从根节点出发,按特征值与阈值的比较走左或右子树,直到叶节点,输出该叶节点的多数类标签。
代码
import sys
from collections import Counter
def gini(labels):
n = len(labels)
if n == 0:
return 0.0
cnt = Counter(labels)
return 1.0 - sum((c / n) ** 2 for c in cnt.values())
def majority_label(labels):
cnt = Counter(labels)
max_count = max(cnt.values())
candidates = [k for k, v in cnt.items() if v == max_count]
return min(candidates)
def build_tree(data, depth, max_depth):
labels = [d[2] for d in data]
if depth >= max_depth or len(data) == 0:
return ('leaf', majority_label(labels))
current_gini = gini(labels)
if current_gini == 0.0:
return ('leaf', majority_label(labels))
best_wg = None
best_feature = None
best_threshold = None
n = len(data)
for feature in [0, 1]:
for t in [-3, -2, -1, 0, 1, 2, 3]:
left = [d for d in data if d[feature] < t]
right = [d for d in data if d[feature] >= t]
if not left or not right:
continue
wg = (len(left) / n) * gini([d[2] for d in left]) + \
(len(right) / n) * gini([d[2] for d in right])
if wg >= current_gini:
continue
if best_wg is None or wg < best_wg:
best_wg = wg
best_feature = feature
best_threshold = t
if best_wg is None:
return ('leaf', majority_label(labels))
left_data = [d for d in data if d[best_feature] < best_threshold]
right_data = [d for d in data if d[best_feature] >= best_threshold]
return ('node', best_feature, best_threshold,
build_tree(left_data, depth + 1, max_depth),
build_tree(right_data, depth + 1, max_depth))
def predict(tree, x1, x2):
if tree[0] == 'leaf':
return tree[1]
_, feature, threshold, left, right = tree
if (x1 if feature == 0 else x2) < threshold:
return predict(left, x1, x2)
return predict(right, x1, x2)
def main():
data = sys.stdin.read().split()
idx = 0
M = int(data[idx]); idx += 1
samples = []
for _ in range(M):
x1 = float(data[idx]); x2 = float(data[idx+1]); y = int(data[idx+2])
idx += 3
samples.append((x1, x2, y))
tx1 = float(data[idx]); tx2 = float(data[idx+1])
print(f"{gini([d[2] for d in samples]):.4f}")
tree = build_tree(samples, 1, 5)
print(predict(tree, tx1, tx2))
if __name__ == '__main__':
main()
复杂度分析
- 时间复杂度:
,其中
为树的深度(最大 5),每层对 14 种划分方案各遍历所有样本。
- 空间复杂度:
,存储训练样本和递归栈。

京公网安备 11010502036488号