验证集可达到的最优F1值

决策树剪枝,听着很实际——训练时长得太深容易过拟合,剪掉一些子树让模型变简单。这题问的是:在所有可能的剪枝方案里,哪个方案能让验证集上的 F1 值最大?

思路

先厘清"剪枝"的含义:对任意一个内部节点,可以把它整棵子树砍掉,直接用这个节点自带的标签作为输出。可以同时剪多个节点,也可以一个都不剪。

关键观察:剪枝决策是按树结构递归的。对于节点

  • 如果剪掉 :所有到达 的样本都预测为 的标签,子树怎样无所谓。
  • 如果保留 :样本按特征分裂规则分到左右子树,左右子树各自独立做剪枝决策。

这就是一个树形 DP。对每个节点,维护它所有可能的预测结果,用 二元组表示——TP 是正类被正确预测的数量,FP 是负类被错误预测为正类的数量。为什么不用三元组 ?因为到达该节点的正类样本数是固定的,,所以 唯一确定。

保留节点时,左右子树的结果做笛卡尔积

笛卡尔积会让状态数爆炸怎么办?做 Pareto 剪枝。对于 F1 来说,TP 越大越好,FP 越小越好。如果状态 A 的 TP 不低于 B 且 FP 不高于 B(且不完全相同),那 A 一定不比 B 差,B 可以扔掉。按 TP 降序排列后,只保留 FP 严格递减的状态,就是 Pareto 前沿。

最后在根节点的所有 Pareto 最优状态里,逐个算 ,取最大值。

复杂度

  • 时间:每个节点的 Pareto 前沿大小不超过 ,其中 分别是到达该节点的正负样本数。笛卡尔积后再做 Pareto 剪枝,整体复杂度与树的结构和样本分布有关,实际远小于暴力枚举。
  • 空间:

代码

import sys

def main():
    input_data = sys.stdin.buffer.read().split()
    idx = 0
    def rd():
        nonlocal idx
        val = int(input_data[idx]); idx += 1
        return val

    N, M, K = rd(), rd(), rd()
    lc = [0] * (N + 1)
    rc = [0] * (N + 1)
    ft = [0] * (N + 1)
    th = [0] * (N + 1)
    lb = [0] * (N + 1)

    for i in range(1, N + 1):
        lc[i], rc[i], ft[i], th[i], lb[i] = rd(), rd(), rd(), rd(), rd()

    features = []
    true_labels = []
    for _ in range(M):
        f = [rd() for _ in range(K)]
        true_labels.append(rd())
        features.append(f)

    def solve(node, sids):
        pos = sum(1 for si in sids if true_labels[si] == 1)
        neg = len(sids) - pos

        # 剪枝选项:用该节点标签预测所有到达的样本
        if lb[node] == 1:
            res = [(pos, neg)]
        else:
            res = [(0, 0)]

        if lc[node] == 0 or not sids:
            return res

        # 不剪枝:按特征分裂到左右子树
        f, t = ft[node], th[node]
        ls = [si for si in sids if features[si][f - 1] <= t]
        rs = [si for si in sids if features[si][f - 1] > t]

        lo = solve(lc[node], ls)
        ro = solve(rc[node], rs)

        for ltp, lfp in lo:
            for rtp, rfp in ro:
                res.append((ltp + rtp, lfp + rfp))

        # Pareto 剪枝:TP 越大越好,FP 越小越好
        res.sort(key=lambda x: (-x[0], x[1]))
        pareto = []
        min_fp = float('inf')
        for tp, fp in res:
            if fp < min_fp:
                pareto.append((tp, fp))
                min_fp = fp
        return pareto

    total_pos = sum(true_labels)
    outcomes = solve(1, list(range(M)))

    best = 0.0
    for tp, fp in outcomes:
        fn = total_pos - tp
        denom = 2 * tp + fp + fn
        if denom > 0:
            best = max(best, 2 * tp / denom)

    print(f"{best:.6f}")

main()
#include <bits/stdc++.h>
using namespace std;

int N, M, K;
int lc[105], rc[105], ft[105], th[105], lb[105];
int feat[305][105], tl[305];

struct State { int tp, fp; };

vector<State> solve(int node, vector<int>& sids) {
    int pos = 0;
    for (int si : sids) if (tl[si] == 1) pos++;
    int neg = (int)sids.size() - pos;

    vector<State> res;
    if (lb[node] == 1) res.push_back({pos, neg});
    else res.push_back({0, 0});

    if (lc[node] == 0 || sids.empty()) return res;

    int f = ft[node], t = th[node];
    vector<int> ls, rs;
    for (int si : sids) {
        if (feat[si][f] <= t) ls.push_back(si);
        else rs.push_back(si);
    }

    auto lo = solve(lc[node], ls);
    auto ro = solve(rc[node], rs);

    for (auto& l : lo)
        for (auto& r : ro)
            res.push_back({l.tp + r.tp, l.fp + r.fp});

    sort(res.begin(), res.end(), [](const State& a, const State& b) {
        return a.tp != b.tp ? a.tp > b.tp : a.fp < b.fp;
    });
    vector<State> pareto;
    int minFP = INT_MAX;
    for (auto& s : res) {
        if (s.fp < minFP) {
            pareto.push_back(s);
            minFP = s.fp;
        }
    }
    return pareto;
}

int main() {
    scanf("%d%d%d", &N, &M, &K);
    for (int i = 1; i <= N; i++)
        scanf("%d%d%d%d%d", &lc[i], &rc[i], &ft[i], &th[i], &lb[i]);
    for (int i = 0; i < M; i++) {
        for (int j = 1; j <= K; j++) scanf("%d", &feat[i][j]);
        scanf("%d", &tl[i]);
    }

    vector<int> all(M);
    iota(all.begin(), all.end(), 0);
    int totalPos = 0;
    for (int i = 0; i < M; i++) if (tl[i] == 1) totalPos++;

    auto outcomes = solve(1, all);

    double best = 0;
    for (auto& s : outcomes) {
        int fn = totalPos - s.tp;
        int denom = 2 * s.tp + s.fp + fn;
        if (denom > 0) best = max(best, 2.0 * s.tp / denom);
    }
    printf("%.6f\n", best);
}