题目链接

验证集可达到的最优F1值

题目描述

给定一棵用于二分类的二叉决策树和一个验证集。需要在这棵树上寻找一个最优的“剪枝”方案,使得模型在验证集上的 F1 分数最高。

  • 决策树结构:节点分为内部节点和叶节点。内部节点根据“第 个特征 阈值”的规则将样本分发到左或右子树。叶节点直接输出一个预测类别(0或1)。
  • 剪枝:可以将树中的任何一个内部节点替换为一个叶节点,其预测输出由该节点预设的 leaf_output 决定。可以选择任意多个节点进行剪枝。
  • 目标:找到一种剪枝组合,使得在验证集上得到的 F1 分数最大,并输出这个最大值。

解题思路

这是一个典型的树形动态规划(Tree DP)问题。我们可以通过一次后序遍历(自底向上)来解决。对于树中的每一个节点,我们都计算出对以它为根的子树进行最优剪枝后,能得到的最佳性能指标。

由于 F1 分数是一个全局指标,不是局部可加的,我们不能直接在递归中传递或优化 F1 分数。但是,构成 F1 分数的 混淆矩阵(即 TP, FP, FN 的计数)是可加的。因此,我们的 DP 状态应该是混淆矩阵。

核心算法:后序遍历 + 动态规划

  1. 预处理:样本路由 在进行动态规划之前,我们首先需要知道验证集中的每个样本在不剪枝的情况下,会经过哪些节点。我们可以遍历一次所有 个验证样本,让它们在完整的决策树上走一遍,并记录下每个样本所经过的路径。这样,对于树中的任意一个节点 ,我们就能得到一个列表 samples_at_node[u],其中包含了所有会到达节点 的验证样本。

  2. 定义递归函数(DP核心) 我们定义一个递归函数,例如 solve(u),它返回一个 (TP, FP, FN) 的元组(或结构体),表示对以节点 为根的子树进行最优剪枝后,施加于 samples_at_node[u] 这个样本集上得到的混淆矩阵。

  3. DP状态转移 函数 solve(u) 的逻辑如下:

    • 基本情况(Base Case):如果 是一个叶节点 该节点无法再分子,也无法剪枝。它的行为是固定的。我们遍历 samples_at_node[u] 中的所有样本,根据该叶节点的 leaf_output 作为预测值,与样本的真实标签比较,计算出 TP, FP, FN,并返回。

    • 递归情况(Recursive Step):如果 是一个内部节点 对于节点 ,我们有两种选择: a. 剪枝(Prune):将节点 及其整个子树视为一个叶节点。预测值统一为节点 预设的 leaf_output。我们遍历 samples_at_node[u] 中的所有样本,计算出这种策略下的混淆矩阵 cm_pruned = (TP_p, FP_p, FN_p)。 b. 不剪枝(Don't Prune):保留节点 的分支功能。样本集 samples_at_node[u] 会被 的判断规则分裂,一部分流向左子节点 v_l,另一部分流向右子节点 v_r。我们递归地调用 solve(v_l)solve(v_r),得到左右子树的最优混淆矩阵 cm_leftcm_right。那么,不剪枝策略下的总混淆矩阵就是 cm_not_pruned = cm_left + cm_right(各项指标直接相加)。

    • 决策 现在,我们需要在“剪枝”和“不剪枝”两种策略中做出选择。我们分别基于 cm_prunedcm_not_pruned 计算出它们对应的 F1 分数:f1_prunedf1_not_pruned

      • 如果 f1_pruned > f1_not_pruned,说明对于 samples_at_node[u] 这个局部样本集,剪枝是更优的选择。因此函数返回 cm_pruned
      • 否则,不剪枝更优(或一样好),函数返回 cm_not_pruned
  4. 主流程

    • 执行预处理步骤,得到所有节点的样本归属列表。
    • 从根节点(通常是节点1)开始调用 solve(1)。该调用将返回整棵树在最优剪枝策略下的全局混淆矩阵。
    • 利用这个最终的混淆矩阵计算出全局最优的 F1 分数。
    • 使用记忆化(Memoization)来缓存 solve(u) 的计算结果,避免对同一节点重复计算,从而优化性能。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <iomanip>
#include <map>

using namespace std;

struct Node {
    int left = 0, right = 0;
    int feature_idx = 0;
    int threshold = 0;
    int leaf_output = 0;
};

struct Sample {
    vector<int> features;
    int true_label;
};

struct ConfusionMatrix {
    int tp = 0, fp = 0, fn = 0;
};

vector<Node> tree;
vector<Sample> validation_set;
vector<vector<int>> samples_at_node;
map<int, ConfusionMatrix> memo;

double calculate_f1(const ConfusionMatrix& cm) {
    if (cm.tp == 0) return 0.0;
    double precision = (double)cm.tp / (cm.tp + cm.fp);
    double recall = (double)cm.tp / (cm.tp + cm.fn);
    if (precision + recall == 0) return 0.0;
    return 2 * precision * recall / (precision + recall);
}

ConfusionMatrix solve(int node_id) {
    if (memo.count(node_id)) {
        return memo[node_id];
    }

    Node& current_node = tree[node_id];

    // 基本情况:叶节点
    if (current_node.left == 0 && current_node.right == 0) {
        ConfusionMatrix cm;
        for (int sample_idx : samples_at_node[node_id]) {
            int pred = current_node.leaf_output;
            int true_label = validation_set[sample_idx].true_label;
            if (pred == 1 && true_label == 1) cm.tp++;
            else if (pred == 1 && true_label == 0) cm.fp++;
            else if (pred == 0 && true_label == 1) cm.fn++;
        }
        return memo[node_id] = cm;
    }

    // 递归情况:内部节点
    // 选项1:剪枝
    ConfusionMatrix cm_pruned;
    for (int sample_idx : samples_at_node[node_id]) {
        int pred = current_node.leaf_output;
        int true_label = validation_set[sample_idx].true_label;
        if (pred == 1 && true_label == 1) cm_pruned.tp++;
        else if (pred == 1 && true_label == 0) cm_pruned.fp++;
        else if (pred == 0 && true_label == 1) cm_pruned.fn++;
    }
    double f1_pruned = calculate_f1(cm_pruned);

    // 选项2:不剪枝
    ConfusionMatrix cm_left = solve(current_node.left);
    ConfusionMatrix cm_right = solve(current_node.right);
    ConfusionMatrix cm_not_pruned = {cm_left.tp + cm_right.tp, cm_left.fp + cm_right.fp, cm_left.fn + cm_right.fn};
    double f1_not_pruned = calculate_f1(cm_not_pruned);

    if (f1_pruned > f1_not_pruned) {
        return memo[node_id] = cm_pruned;
    } else {
        return memo[node_id] = cm_not_pruned;
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m, k;
    cin >> n >> m >> k;

    tree.resize(n + 1);
    for (int i = 1; i <= n; ++i) {
        cin >> tree[i].left >> tree[i].right >> tree[i].feature_idx >> tree[i].threshold >> tree[i].leaf_output;
    }

    validation_set.resize(m);
    for (int i = 0; i < m; ++i) {
        validation_set[i].features.resize(k);
        for (int j = 0; j < k; ++j) {
            cin >> validation_set[i].features[j];
        }
        cin >> validation_set[i].true_label;
    }

    // 预处理:样本路由
    samples_at_node.resize(n + 1);
    for (int i = 0; i < m; ++i) {
        int current_node_id = 1;
        while (tree[current_node_id].left != 0) { // 只要是内部节点
            samples_at_node[current_node_id].push_back(i);
            const auto& node = tree[current_node_id];
            if (validation_set[i].features[node.feature_idx - 1] <= node.threshold) {
                current_node_id = node.left;
            } else {
                current_node_id = node.right;
            }
        }
        samples_at_node[current_node_id].push_back(i); // 到达叶子
    }
    
    ConfusionMatrix final_cm = solve(1);
    double max_f1 = calculate_f1(final_cm);

    cout << fixed << setprecision(6) << max_f1 << "\n";

    return 0;
}
import java.util.*;

class Node {
    int left, right, featureIdx, threshold, leafOutput;
}

class Sample {
    int[] features;
    int trueLabel;
}

class ConfusionMatrix {
    int tp = 0, fp = 0, fn = 0;
}

public class Main {
    static Node[] tree;
    static Sample[] validationSet;
    static List<Integer>[] samplesAtNode;
    static Map<Integer, ConfusionMatrix> memo = new HashMap<>();

    private static double calculateF1(ConfusionMatrix cm) {
        if (cm.tp == 0) return 0.0;
        double precision = (double) cm.tp / (cm.tp + cm.fp);
        double recall = (double) cm.tp / (cm.tp + cm.fn);
        if (precision + recall == 0) return 0.0;
        return 2 * precision * recall / (precision + recall);
    }

    private static ConfusionMatrix solve(int nodeId) {
        if (memo.containsKey(nodeId)) {
            return memo.get(nodeId);
        }

        Node currentNode = tree[nodeId];
        
        // 基本情况:叶节点
        if (currentNode.left == 0) {
            ConfusionMatrix cm = new ConfusionMatrix();
            for (int sampleIdx : samplesAtNode[nodeId]) {
                int pred = currentNode.leafOutput;
                int trueLabel = validationSet[sampleIdx].trueLabel;
                if (pred == 1 && trueLabel == 1) cm.tp++;
                else if (pred == 1 && trueLabel == 0) cm.fp++;
                else if (pred == 0 && trueLabel == 1) cm.fn++;
            }
            memo.put(nodeId, cm);
            return cm;
        }

        // 递归情况:内部节点
        // 选项1:剪枝
        ConfusionMatrix cmPruned = new ConfusionMatrix();
        for (int sampleIdx : samplesAtNode[nodeId]) {
            int pred = currentNode.leafOutput;
            int trueLabel = validationSet[sampleIdx].trueLabel;
            if (pred == 1 && trueLabel == 1) cmPruned.tp++;
            else if (pred == 1 && trueLabel == 0) cmPruned.fp++;
            else if (pred == 0 && trueLabel == 1) cmPruned.fn++;
        }
        double f1Pruned = calculateF1(cmPruned);

        // 选项2:不剪枝
        ConfusionMatrix cmLeft = solve(currentNode.left);
        ConfusionMatrix cmRight = solve(currentNode.right);
        ConfusionMatrix cmNotPruned = new ConfusionMatrix();
        cmNotPruned.tp = cmLeft.tp + cmRight.tp;
        cmNotPruned.fp = cmLeft.fp + cmRight.fp;
        cmNotPruned.fn = cmLeft.fn + cmRight.fn;
        double f1NotPruned = calculateF1(cmNotPruned);

        ConfusionMatrix result = (f1Pruned > f1NotPruned) ? cmPruned : cmNotPruned;
        memo.put(nodeId, result);
        return result;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        int k = sc.nextInt();

        tree = new Node[n + 1];
        for (int i = 1; i <= n; i++) {
            tree[i] = new Node();
            tree[i].left = sc.nextInt();
            tree[i].right = sc.nextInt();
            tree[i].featureIdx = sc.nextInt();
            tree[i].threshold = sc.nextInt();
            tree[i].leafOutput = sc.nextInt();
        }

        validationSet = new Sample[m];
        for (int i = 0; i < m; i++) {
            validationSet[i] = new Sample();
            validationSet[i].features = new int[k];
            for (int j = 0; j < k; j++) {
                validationSet[i].features[j] = sc.nextInt();
            }
            validationSet[i].trueLabel = sc.nextInt();
        }

        // 预处理:样本路由
        samplesAtNode = new ArrayList[n + 1];
        for (int i = 0; i <= n; i++) {
            samplesAtNode[i] = new ArrayList<>();
        }

        for (int i = 0; i < m; i++) {
            int currentNodeId = 1;
            while (tree[currentNodeId].left != 0) {
                samplesAtNode[currentNodeId].add(i);
                Node node = tree[currentNodeId];
                if (validationSet[i].features[node.featureIdx - 1] <= node.threshold) {
                    currentNodeId = node.left;
                } else {
                    currentNodeId = node.right;
                }
            }
            samplesAtNode[currentNodeId].add(i);
        }

        ConfusionMatrix finalCm = solve(1);
        double maxF1 = calculateF1(finalCm);

        System.out.printf("%.6f\n", maxF1);
    }
}
import sys

memo = {}
tree = {}
validation_set = []
samples_at_node = {}

def calculate_f1(cm):
    """根据混淆矩阵计算F1分数"""
    tp, fp, fn = cm['tp'], cm['fp'], cm['fn']
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

def solve(node_id):
    """递归函数,计算以node_id为根的子树的最优混淆矩阵"""
    if node_id in memo:
        return memo[node_id]

    current_node = tree[node_id]
    
    # 基本情况:叶节点
    if current_node['left'] == 0:
        cm = {'tp': 0, 'fp': 0, 'fn': 0}
        for sample_idx in samples_at_node.get(node_id, []):
            pred = current_node['leaf_output']
            true_label = validation_set[sample_idx]['true_label']
            if pred == 1 and true_label == 1: cm['tp'] += 1
            elif pred == 1 and true_label == 0: cm['fp'] += 1
            elif pred == 0 and true_label == 1: cm['fn'] += 1
        memo[node_id] = cm
        return cm

    # 递归情况:内部节点
    # 选项1:剪枝
    cm_pruned = {'tp': 0, 'fp': 0, 'fn': 0}
    for sample_idx in samples_at_node.get(node_id, []):
        pred = current_node['leaf_output']
        true_label = validation_set[sample_idx]['true_label']
        if pred == 1 and true_label == 1: cm_pruned['tp'] += 1
        elif pred == 1 and true_label == 0: cm_pruned['fp'] += 1
        elif pred == 0 and true_label == 1: cm_pruned['fn'] += 1
    f1_pruned = calculate_f1(cm_pruned)
    
    # 选项2:不剪枝
    cm_left = solve(current_node['left'])
    cm_right = solve(current_node['right'])
    cm_not_pruned = {
        'tp': cm_left['tp'] + cm_right['tp'],
        'fp': cm_left['fp'] + cm_right['fp'],
        'fn': cm_left['fn'] + cm_right['fn']
    }
    f1_not_pruned = calculate_f1(cm_not_pruned)

    if f1_pruned > f1_not_pruned:
        result = cm_pruned
    else:
        result = cm_not_pruned
        
    memo[node_id] = result
    return result

def main():
    sys.setrecursionlimit(2000) # 防止递归深度超限
    n, m, k = map(int, input().split())

    for i in range(1, n + 1):
        l, r, f_idx, t, out = map(int, input().split())
        tree[i] = {'left': l, 'right': r, 'feature_idx': f_idx, 'threshold': t, 'leaf_output': out}
    
    for _ in range(m):
        line = list(map(int, input().split()))
        validation_set.append({'features': line[:-1], 'true_label': line[-1]})

    # 预处理:样本路由
    for i in range(n + 1):
        samples_at_node[i] = []

    for i in range(m):
        current_node_id = 1
        while tree[current_node_id]['left'] != 0:
            samples_at_node[current_node_id].append(i)
            node = tree[current_node_id]
            if validation_set[i]['features'][node['feature_idx'] - 1] <= node['threshold']:
                current_node_id = node['left']
            else:
                current_node_id = node['right']
        samples_at_node[current_node_id].append(i)
    
    final_cm = solve(1)
    max_f1 = calculate_f1(final_cm)

    print(f"{max_f1:.6f}")

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:树形动态规划 (Tree DP) + 记忆化搜索
  • 时间复杂度:,其中 是节点数, 是验证样本数, 是树的最大深度。
    • 样本路由:每个样本最多遍历树的深度一次,总时间为
    • 动态规划solve 函数会对每个节点计算一次。在每个节点 的计算中,都需要遍历 samples_at_node[u] 中的所有样本。由于一个样本最多属于 个节点的 samples_at_node 列表,所有列表的总样本数是 。因此,DP部分的总计算量与此成正比。在最坏情况下(链状树, D=N),总复杂度近似为
  • 空间复杂度:
    • 存储树结构和记忆化表。
    • 存储 samples_at_node 列表。
    • 存储验证集数据。 在最坏情况下,空间复杂度也近似为