验证集可达到的最优F1值
决策树剪枝,听着很实际——训练时长得太深容易过拟合,剪掉一些子树让模型变简单。这题问的是:在所有可能的剪枝方案里,哪个方案能让验证集上的 F1 值最大?
思路
先厘清"剪枝"的含义:对任意一个内部节点,可以把它整棵子树砍掉,直接用这个节点自带的标签作为输出。可以同时剪多个节点,也可以一个都不剪。
关键观察:剪枝决策是按树结构递归的。对于节点 :
- 如果剪掉
:所有到达
的样本都预测为
的标签,子树怎样无所谓。
- 如果保留
:样本按特征分裂规则分到左右子树,左右子树各自独立做剪枝决策。
这就是一个树形 DP。对每个节点,维护它所有可能的预测结果,用 二元组表示——TP 是正类被正确预测的数量,FP 是负类被错误预测为正类的数量。为什么不用三元组
?因为到达该节点的正类样本数是固定的,
,所以
由
唯一确定。
保留节点时,左右子树的结果做笛卡尔积:,
。
笛卡尔积会让状态数爆炸怎么办?做 Pareto 剪枝。对于 F1 来说,TP 越大越好,FP 越小越好。如果状态 A 的 TP 不低于 B 且 FP 不高于 B(且不完全相同),那 A 一定不比 B 差,B 可以扔掉。按 TP 降序排列后,只保留 FP 严格递减的状态,就是 Pareto 前沿。
最后在根节点的所有 Pareto 最优状态里,逐个算 ,取最大值。
复杂度
- 时间:每个节点的 Pareto 前沿大小不超过
,其中
分别是到达该节点的正负样本数。笛卡尔积后再做 Pareto 剪枝,整体复杂度与树的结构和样本分布有关,实际远小于暴力枚举。
- 空间:
代码
import sys
def main():
input_data = sys.stdin.buffer.read().split()
idx = 0
def rd():
nonlocal idx
val = int(input_data[idx]); idx += 1
return val
N, M, K = rd(), rd(), rd()
lc = [0] * (N + 1)
rc = [0] * (N + 1)
ft = [0] * (N + 1)
th = [0] * (N + 1)
lb = [0] * (N + 1)
for i in range(1, N + 1):
lc[i], rc[i], ft[i], th[i], lb[i] = rd(), rd(), rd(), rd(), rd()
features = []
true_labels = []
for _ in range(M):
f = [rd() for _ in range(K)]
true_labels.append(rd())
features.append(f)
def solve(node, sids):
pos = sum(1 for si in sids if true_labels[si] == 1)
neg = len(sids) - pos
# 剪枝选项:用该节点标签预测所有到达的样本
if lb[node] == 1:
res = [(pos, neg)]
else:
res = [(0, 0)]
if lc[node] == 0 or not sids:
return res
# 不剪枝:按特征分裂到左右子树
f, t = ft[node], th[node]
ls = [si for si in sids if features[si][f - 1] <= t]
rs = [si for si in sids if features[si][f - 1] > t]
lo = solve(lc[node], ls)
ro = solve(rc[node], rs)
for ltp, lfp in lo:
for rtp, rfp in ro:
res.append((ltp + rtp, lfp + rfp))
# Pareto 剪枝:TP 越大越好,FP 越小越好
res.sort(key=lambda x: (-x[0], x[1]))
pareto = []
min_fp = float('inf')
for tp, fp in res:
if fp < min_fp:
pareto.append((tp, fp))
min_fp = fp
return pareto
total_pos = sum(true_labels)
outcomes = solve(1, list(range(M)))
best = 0.0
for tp, fp in outcomes:
fn = total_pos - tp
denom = 2 * tp + fp + fn
if denom > 0:
best = max(best, 2 * tp / denom)
print(f"{best:.6f}")
main()
#include <bits/stdc++.h>
using namespace std;
int N, M, K;
int lc[105], rc[105], ft[105], th[105], lb[105];
int feat[305][105], tl[305];
struct State { int tp, fp; };
vector<State> solve(int node, vector<int>& sids) {
int pos = 0;
for (int si : sids) if (tl[si] == 1) pos++;
int neg = (int)sids.size() - pos;
vector<State> res;
if (lb[node] == 1) res.push_back({pos, neg});
else res.push_back({0, 0});
if (lc[node] == 0 || sids.empty()) return res;
int f = ft[node], t = th[node];
vector<int> ls, rs;
for (int si : sids) {
if (feat[si][f] <= t) ls.push_back(si);
else rs.push_back(si);
}
auto lo = solve(lc[node], ls);
auto ro = solve(rc[node], rs);
for (auto& l : lo)
for (auto& r : ro)
res.push_back({l.tp + r.tp, l.fp + r.fp});
sort(res.begin(), res.end(), [](const State& a, const State& b) {
return a.tp != b.tp ? a.tp > b.tp : a.fp < b.fp;
});
vector<State> pareto;
int minFP = INT_MAX;
for (auto& s : res) {
if (s.fp < minFP) {
pareto.push_back(s);
minFP = s.fp;
}
}
return pareto;
}
int main() {
scanf("%d%d%d", &N, &M, &K);
for (int i = 1; i <= N; i++)
scanf("%d%d%d%d%d", &lc[i], &rc[i], &ft[i], &th[i], &lb[i]);
for (int i = 0; i < M; i++) {
for (int j = 1; j <= K; j++) scanf("%d", &feat[i][j]);
scanf("%d", &tl[i]);
}
vector<int> all(M);
iota(all.begin(), all.end(), 0);
int totalPos = 0;
for (int i = 0; i < M; i++) if (tl[i] == 1) totalPos++;
auto outcomes = solve(1, all);
double best = 0;
for (auto& s : outcomes) {
int fn = totalPos - s.tp;
int denom = 2 * s.tp + s.fp + fn;
if (denom > 0) best = max(best, 2.0 * s.tp / denom);
}
printf("%.6f\n", best);
}

京公网安备 11010502036488号