题目链接
题目描述
给定一棵用于二分类的二叉决策树和一个验证集。需要在这棵树上寻找一个最优的“剪枝”方案,使得模型在验证集上的 F1 分数最高。
- 决策树结构:节点分为内部节点和叶节点。内部节点根据“第
个特征
阈值”的规则将样本分发到左或右子树。叶节点直接输出一个预测类别(0或1)。
- 剪枝:可以将树中的任何一个内部节点替换为一个叶节点,其预测输出由该节点预设的
leaf_output
决定。可以选择任意多个节点进行剪枝。 - 目标:找到一种剪枝组合,使得在验证集上得到的 F1 分数最大,并输出这个最大值。
解题思路
这是一个典型的树形动态规划(Tree DP)问题。我们可以通过一次后序遍历(自底向上)来解决。对于树中的每一个节点,我们都计算出对以它为根的子树进行最优剪枝后,能得到的最佳性能指标。
由于 F1 分数是一个全局指标,不是局部可加的,我们不能直接在递归中传递或优化 F1 分数。但是,构成 F1 分数的 混淆矩阵(即 TP, FP, FN 的计数)是可加的。因此,我们的 DP 状态应该是混淆矩阵。
核心算法:后序遍历 + 动态规划
-
预处理:样本路由 在进行动态规划之前,我们首先需要知道验证集中的每个样本在不剪枝的情况下,会经过哪些节点。我们可以遍历一次所有
个验证样本,让它们在完整的决策树上走一遍,并记录下每个样本所经过的路径。这样,对于树中的任意一个节点
,我们就能得到一个列表
samples_at_node[u]
,其中包含了所有会到达节点的验证样本。
-
定义递归函数(DP核心) 我们定义一个递归函数,例如
solve(u)
,它返回一个(TP, FP, FN)
的元组(或结构体),表示对以节点为根的子树进行最优剪枝后,施加于
samples_at_node[u]
这个样本集上得到的混淆矩阵。 -
DP状态转移 函数
solve(u)
的逻辑如下:-
基本情况(Base Case):如果
是一个叶节点 该节点无法再分子,也无法剪枝。它的行为是固定的。我们遍历
samples_at_node[u]
中的所有样本,根据该叶节点的leaf_output
作为预测值,与样本的真实标签比较,计算出TP, FP, FN
,并返回。 -
递归情况(Recursive Step):如果
是一个内部节点 对于节点
,我们有两种选择: a. 剪枝(Prune):将节点
及其整个子树视为一个叶节点。预测值统一为节点
预设的
leaf_output
。我们遍历samples_at_node[u]
中的所有样本,计算出这种策略下的混淆矩阵cm_pruned = (TP_p, FP_p, FN_p)
。 b. 不剪枝(Don't Prune):保留节点的分支功能。样本集
samples_at_node[u]
会被的判断规则分裂,一部分流向左子节点
v_l
,另一部分流向右子节点v_r
。我们递归地调用solve(v_l)
和solve(v_r)
,得到左右子树的最优混淆矩阵cm_left
和cm_right
。那么,不剪枝策略下的总混淆矩阵就是cm_not_pruned = cm_left + cm_right
(各项指标直接相加)。 -
决策 现在,我们需要在“剪枝”和“不剪枝”两种策略中做出选择。我们分别基于
cm_pruned
和cm_not_pruned
计算出它们对应的 F1 分数:f1_pruned
和f1_not_pruned
。- 如果
f1_pruned > f1_not_pruned
,说明对于samples_at_node[u]
这个局部样本集,剪枝是更优的选择。因此函数返回cm_pruned
。 - 否则,不剪枝更优(或一样好),函数返回
cm_not_pruned
。
- 如果
-
-
主流程
- 执行预处理步骤,得到所有节点的样本归属列表。
- 从根节点(通常是节点1)开始调用
solve(1)
。该调用将返回整棵树在最优剪枝策略下的全局混淆矩阵。 - 利用这个最终的混淆矩阵计算出全局最优的 F1 分数。
- 使用记忆化(Memoization)来缓存
solve(u)
的计算结果,避免对同一节点重复计算,从而优化性能。
代码
#include <iostream>
#include <vector>
#include <cmath>
#include <iomanip>
#include <map>
using namespace std;
struct Node {
int left = 0, right = 0;
int feature_idx = 0;
int threshold = 0;
int leaf_output = 0;
};
struct Sample {
vector<int> features;
int true_label;
};
struct ConfusionMatrix {
int tp = 0, fp = 0, fn = 0;
};
vector<Node> tree;
vector<Sample> validation_set;
vector<vector<int>> samples_at_node;
map<int, ConfusionMatrix> memo;
double calculate_f1(const ConfusionMatrix& cm) {
if (cm.tp == 0) return 0.0;
double precision = (double)cm.tp / (cm.tp + cm.fp);
double recall = (double)cm.tp / (cm.tp + cm.fn);
if (precision + recall == 0) return 0.0;
return 2 * precision * recall / (precision + recall);
}
ConfusionMatrix solve(int node_id) {
if (memo.count(node_id)) {
return memo[node_id];
}
Node& current_node = tree[node_id];
// 基本情况:叶节点
if (current_node.left == 0 && current_node.right == 0) {
ConfusionMatrix cm;
for (int sample_idx : samples_at_node[node_id]) {
int pred = current_node.leaf_output;
int true_label = validation_set[sample_idx].true_label;
if (pred == 1 && true_label == 1) cm.tp++;
else if (pred == 1 && true_label == 0) cm.fp++;
else if (pred == 0 && true_label == 1) cm.fn++;
}
return memo[node_id] = cm;
}
// 递归情况:内部节点
// 选项1:剪枝
ConfusionMatrix cm_pruned;
for (int sample_idx : samples_at_node[node_id]) {
int pred = current_node.leaf_output;
int true_label = validation_set[sample_idx].true_label;
if (pred == 1 && true_label == 1) cm_pruned.tp++;
else if (pred == 1 && true_label == 0) cm_pruned.fp++;
else if (pred == 0 && true_label == 1) cm_pruned.fn++;
}
double f1_pruned = calculate_f1(cm_pruned);
// 选项2:不剪枝
ConfusionMatrix cm_left = solve(current_node.left);
ConfusionMatrix cm_right = solve(current_node.right);
ConfusionMatrix cm_not_pruned = {cm_left.tp + cm_right.tp, cm_left.fp + cm_right.fp, cm_left.fn + cm_right.fn};
double f1_not_pruned = calculate_f1(cm_not_pruned);
if (f1_pruned > f1_not_pruned) {
return memo[node_id] = cm_pruned;
} else {
return memo[node_id] = cm_not_pruned;
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int n, m, k;
cin >> n >> m >> k;
tree.resize(n + 1);
for (int i = 1; i <= n; ++i) {
cin >> tree[i].left >> tree[i].right >> tree[i].feature_idx >> tree[i].threshold >> tree[i].leaf_output;
}
validation_set.resize(m);
for (int i = 0; i < m; ++i) {
validation_set[i].features.resize(k);
for (int j = 0; j < k; ++j) {
cin >> validation_set[i].features[j];
}
cin >> validation_set[i].true_label;
}
// 预处理:样本路由
samples_at_node.resize(n + 1);
for (int i = 0; i < m; ++i) {
int current_node_id = 1;
while (tree[current_node_id].left != 0) { // 只要是内部节点
samples_at_node[current_node_id].push_back(i);
const auto& node = tree[current_node_id];
if (validation_set[i].features[node.feature_idx - 1] <= node.threshold) {
current_node_id = node.left;
} else {
current_node_id = node.right;
}
}
samples_at_node[current_node_id].push_back(i); // 到达叶子
}
ConfusionMatrix final_cm = solve(1);
double max_f1 = calculate_f1(final_cm);
cout << fixed << setprecision(6) << max_f1 << "\n";
return 0;
}
import java.util.*;
class Node {
int left, right, featureIdx, threshold, leafOutput;
}
class Sample {
int[] features;
int trueLabel;
}
class ConfusionMatrix {
int tp = 0, fp = 0, fn = 0;
}
public class Main {
static Node[] tree;
static Sample[] validationSet;
static List<Integer>[] samplesAtNode;
static Map<Integer, ConfusionMatrix> memo = new HashMap<>();
private static double calculateF1(ConfusionMatrix cm) {
if (cm.tp == 0) return 0.0;
double precision = (double) cm.tp / (cm.tp + cm.fp);
double recall = (double) cm.tp / (cm.tp + cm.fn);
if (precision + recall == 0) return 0.0;
return 2 * precision * recall / (precision + recall);
}
private static ConfusionMatrix solve(int nodeId) {
if (memo.containsKey(nodeId)) {
return memo.get(nodeId);
}
Node currentNode = tree[nodeId];
// 基本情况:叶节点
if (currentNode.left == 0) {
ConfusionMatrix cm = new ConfusionMatrix();
for (int sampleIdx : samplesAtNode[nodeId]) {
int pred = currentNode.leafOutput;
int trueLabel = validationSet[sampleIdx].trueLabel;
if (pred == 1 && trueLabel == 1) cm.tp++;
else if (pred == 1 && trueLabel == 0) cm.fp++;
else if (pred == 0 && trueLabel == 1) cm.fn++;
}
memo.put(nodeId, cm);
return cm;
}
// 递归情况:内部节点
// 选项1:剪枝
ConfusionMatrix cmPruned = new ConfusionMatrix();
for (int sampleIdx : samplesAtNode[nodeId]) {
int pred = currentNode.leafOutput;
int trueLabel = validationSet[sampleIdx].trueLabel;
if (pred == 1 && trueLabel == 1) cmPruned.tp++;
else if (pred == 1 && trueLabel == 0) cmPruned.fp++;
else if (pred == 0 && trueLabel == 1) cmPruned.fn++;
}
double f1Pruned = calculateF1(cmPruned);
// 选项2:不剪枝
ConfusionMatrix cmLeft = solve(currentNode.left);
ConfusionMatrix cmRight = solve(currentNode.right);
ConfusionMatrix cmNotPruned = new ConfusionMatrix();
cmNotPruned.tp = cmLeft.tp + cmRight.tp;
cmNotPruned.fp = cmLeft.fp + cmRight.fp;
cmNotPruned.fn = cmLeft.fn + cmRight.fn;
double f1NotPruned = calculateF1(cmNotPruned);
ConfusionMatrix result = (f1Pruned > f1NotPruned) ? cmPruned : cmNotPruned;
memo.put(nodeId, result);
return result;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
int k = sc.nextInt();
tree = new Node[n + 1];
for (int i = 1; i <= n; i++) {
tree[i] = new Node();
tree[i].left = sc.nextInt();
tree[i].right = sc.nextInt();
tree[i].featureIdx = sc.nextInt();
tree[i].threshold = sc.nextInt();
tree[i].leafOutput = sc.nextInt();
}
validationSet = new Sample[m];
for (int i = 0; i < m; i++) {
validationSet[i] = new Sample();
validationSet[i].features = new int[k];
for (int j = 0; j < k; j++) {
validationSet[i].features[j] = sc.nextInt();
}
validationSet[i].trueLabel = sc.nextInt();
}
// 预处理:样本路由
samplesAtNode = new ArrayList[n + 1];
for (int i = 0; i <= n; i++) {
samplesAtNode[i] = new ArrayList<>();
}
for (int i = 0; i < m; i++) {
int currentNodeId = 1;
while (tree[currentNodeId].left != 0) {
samplesAtNode[currentNodeId].add(i);
Node node = tree[currentNodeId];
if (validationSet[i].features[node.featureIdx - 1] <= node.threshold) {
currentNodeId = node.left;
} else {
currentNodeId = node.right;
}
}
samplesAtNode[currentNodeId].add(i);
}
ConfusionMatrix finalCm = solve(1);
double maxF1 = calculateF1(finalCm);
System.out.printf("%.6f\n", maxF1);
}
}
import sys
memo = {}
tree = {}
validation_set = []
samples_at_node = {}
def calculate_f1(cm):
"""根据混淆矩阵计算F1分数"""
tp, fp, fn = cm['tp'], cm['fp'], cm['fn']
if tp == 0:
return 0.0
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def solve(node_id):
"""递归函数,计算以node_id为根的子树的最优混淆矩阵"""
if node_id in memo:
return memo[node_id]
current_node = tree[node_id]
# 基本情况:叶节点
if current_node['left'] == 0:
cm = {'tp': 0, 'fp': 0, 'fn': 0}
for sample_idx in samples_at_node.get(node_id, []):
pred = current_node['leaf_output']
true_label = validation_set[sample_idx]['true_label']
if pred == 1 and true_label == 1: cm['tp'] += 1
elif pred == 1 and true_label == 0: cm['fp'] += 1
elif pred == 0 and true_label == 1: cm['fn'] += 1
memo[node_id] = cm
return cm
# 递归情况:内部节点
# 选项1:剪枝
cm_pruned = {'tp': 0, 'fp': 0, 'fn': 0}
for sample_idx in samples_at_node.get(node_id, []):
pred = current_node['leaf_output']
true_label = validation_set[sample_idx]['true_label']
if pred == 1 and true_label == 1: cm_pruned['tp'] += 1
elif pred == 1 and true_label == 0: cm_pruned['fp'] += 1
elif pred == 0 and true_label == 1: cm_pruned['fn'] += 1
f1_pruned = calculate_f1(cm_pruned)
# 选项2:不剪枝
cm_left = solve(current_node['left'])
cm_right = solve(current_node['right'])
cm_not_pruned = {
'tp': cm_left['tp'] + cm_right['tp'],
'fp': cm_left['fp'] + cm_right['fp'],
'fn': cm_left['fn'] + cm_right['fn']
}
f1_not_pruned = calculate_f1(cm_not_pruned)
if f1_pruned > f1_not_pruned:
result = cm_pruned
else:
result = cm_not_pruned
memo[node_id] = result
return result
def main():
sys.setrecursionlimit(2000) # 防止递归深度超限
n, m, k = map(int, input().split())
for i in range(1, n + 1):
l, r, f_idx, t, out = map(int, input().split())
tree[i] = {'left': l, 'right': r, 'feature_idx': f_idx, 'threshold': t, 'leaf_output': out}
for _ in range(m):
line = list(map(int, input().split()))
validation_set.append({'features': line[:-1], 'true_label': line[-1]})
# 预处理:样本路由
for i in range(n + 1):
samples_at_node[i] = []
for i in range(m):
current_node_id = 1
while tree[current_node_id]['left'] != 0:
samples_at_node[current_node_id].append(i)
node = tree[current_node_id]
if validation_set[i]['features'][node['feature_idx'] - 1] <= node['threshold']:
current_node_id = node['left']
else:
current_node_id = node['right']
samples_at_node[current_node_id].append(i)
final_cm = solve(1)
max_f1 = calculate_f1(final_cm)
print(f"{max_f1:.6f}")
if __name__ == "__main__":
main()
算法及复杂度
- 算法:树形动态规划 (Tree DP) + 记忆化搜索
- 时间复杂度:
,其中
是节点数,
是验证样本数,
是树的最大深度。
- 样本路由:每个样本最多遍历树的深度一次,总时间为
。
- 动态规划:
solve
函数会对每个节点计算一次。在每个节点的计算中,都需要遍历
samples_at_node[u]
中的所有样本。由于一个样本最多属于个节点的
samples_at_node
列表,所有列表的总样本数是。因此,DP部分的总计算量与此成正比。在最坏情况下(链状树, D=N),总复杂度近似为
。
- 样本路由:每个样本最多遍历树的深度一次,总时间为
- 空间复杂度:
。
存储树结构和记忆化表。
存储
samples_at_node
列表。存储验证集数据。 在最坏情况下,空间复杂度也近似为
。