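# 节点编号和特征编号不从0开始是一大坑...
# (Pitfall: node ids and feature ids in the input do not start from 0; they
# are shifted to 0-based when the tree is read.)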
class BinTree:
    """Binary decision tree that finds the best F1 score over single prunings.

    Construction classifies every sample with the unpruned tree, then once
    more for each node id in ``range(N)`` with that node turned into a leaf,
    and keeps the highest F1 score seen in ``f1_max``.

    Attributes:
        N: number of node ids tried as pruning points.
        M: number of samples that are classified and scored.
        K: number of features per sample (stored; no method here reads it).
        tree: indexable collection of node dicts. Each node has the keys
            ``'l'`` / ``'r'`` (child node ids, ``-1`` when absent), ``'f'``
            (feature index compared at this node), ``'th'`` (threshold) and
            ``'label'`` (prediction returned when traversal stops here).
            Node 0 is the root.
        samples: indexable collection of feature vectors.
        labels: ground-truth labels; each must be 0 or 1 because they index
            a 2x2 confusion matrix.
        f1_max: best F1 score found (float).
    """

    def __init__(self, N, M, K, tree, samples, labels):
        self.N = N
        self.M = M
        self.K = K
        self.tree = tree
        self.samples = samples
        self.labels = labels

        # Baseline: the tree exactly as given.
        self.f1_max = self.f1_score(self.bin_classify_all())
        # Then try stopping at each node in turn and keep the best score.
        for cut_id in range(self.N):
            f1 = self.f1_score(self.bin_classify_all_cut(cut_id))
            if f1 > self.f1_max:
                self.f1_max = f1

    def bin_classify_all(self):
        """Return the unpruned tree's predictions for the first M samples."""
        return self.bin_classify_all_cut(None)

    def bin_classify_one(self, sample):
        """Return the unpruned tree's prediction for a single sample."""
        return self.bin_classify_one_cut(sample, None)

    def bin_classify_all_cut(self, cut_id=None):
        """Return predictions for the first M samples, pruning at ``cut_id``.

        Args:
            cut_id: id of the node to treat as a leaf, or ``None`` for no
                pruning.

        Returns:
            List of M predicted labels, in sample order.
        """
        return [
            self.bin_classify_one_cut(self.samples[ii], cut_id)
            for ii in range(self.M)
        ]

    def bin_classify_one_cut(self, sample, cut_id=None):
        """Classify one sample, stopping early if node ``cut_id`` is reached.

        Traversal starts at node 0. At each node the sample goes left when
        ``sample[f] <= th`` and right otherwise. It stops at the first node
        that is missing a child or whose id equals ``cut_id``.

        Args:
            sample: feature vector indexable by the nodes' ``'f'`` values.
            cut_id: id of the node to treat as a leaf, or ``None`` for no
                pruning (no node id compares equal to ``None``).

        Returns:
            The ``'label'`` of the node where traversal stopped.
        """
        node_id = 0
        while node_id != cut_id:
            node = self.tree[node_id]
            # A node missing either child is treated as a leaf.
            if node['l'] == -1 or node['r'] == -1:
                break
            if sample[node['f']] <= node['th']:
                node_id = node['l']
            else:
                node_id = node['r']
        return self.tree[node_id]['label']

    def f1_score(self, labels_pred):
        """Return the F1 score of class 1 for ``labels_pred``.

        Only the first M entries of ``labels_pred`` and ``self.labels`` are
        scored.

        Args:
            labels_pred: predicted labels (0 or 1), aligned with
                ``self.labels``.

        Returns:
            ``2 * precision * recall / (precision + recall)``, or ``0.0``
            when there are no true positives.
        """
        # counts[truth][prediction]; floats keep the arithmetic below in
        # floating point throughout.
        counts = [[0.0, 0.0], [0.0, 0.0]]
        for ii in range(self.M):
            counts[self.labels[ii]][labels_pred[ii]] += 1.0

        true_pos = counts[1][1]
        false_pos = counts[0][1]
        false_neg = counts[1][0]

        # With no true positives, precision and recall are 0 or undefined;
        # report 0.0 instead of dividing by zero.
        if true_pos == 0:
            return 0.0

        precision = true_pos / (true_pos + false_pos)
        recall = true_pos / (true_pos + false_neg)
        return 2 * precision * recall / (precision + recall)
if __name__ == '__main__':
    # Header line: node count, sample count, feature count.
    n, m, k = (int(token) for token in input().split())

    # One line per node: left child, right child, feature id, threshold,
    # label. Child and feature ids are shifted down by one on read.
    tree_read = []
    for _ in range(n):
        fields = input().split()
        tree_read.append(dict(
            l=int(fields[0]) - 1,
            r=int(fields[1]) - 1,
            f=int(fields[2]) - 1,
            th=int(fields[3]),
            label=int(fields[4]),
        ))

    # One line per sample: k feature values followed by the true label.
    samples_read = []
    labels_read = []
    for _ in range(m):
        row = tuple(int(token) for token in input().split())
        samples_read.append(row[:k])
        labels_read.append(row[k])

    bin_tree = BinTree(n, m, k, tree_read, samples_read, labels_read)
    print('%.6f' % bin_tree.f1_max)

# (web page footer residue, not part of the program) 京公网安备 11010502036488号