决策树是一个用于分类和回归的模型,它通过将数据集分割成更小的子集来构建树形结构。每个内部节点代表一个特征的测试,每个分支代表测试结果,而每个叶子节点则表示最终的输出类别或值。
通俗点说,就是把一堆数据按照某个特征的某个阈值去分成两份或者多份子节点,然后递归执行这种分裂直到达到某种要求。 在本题中,只需要粗暴地对节点内部的所有特征进行尝试分裂,通过计算熵与信息增益来决定使用哪个阈值进行分裂,然后重复执行这一过程即可。 这里给出熵和信息增益的公式:
其中, 是熵,
是类别的数量,
是属于类别
的样本比例。
其中, 是信息增益,
是属性
的所有可能取值,
是在属性
取值为
时的样本子集。
标准代码如下:
import math
from collections import Counter
def calculate_entropy(labels):
label_counts = Counter(labels)
total_count = len(labels)
entropy = -sum(
(count / total_count) * math.log2(count / total_count)
for count in label_counts.values()
)
return entropy
def calculate_information_gain(examples, attr, target_attr):
total_entropy = calculate_entropy([example[target_attr] for example in examples])
values = set(example[attr] for example in examples)
attr_entropy = 0
for value in values:
value_subset = [
example[target_attr] for example in examples if example[attr] == value
]
value_entropy = calculate_entropy(value_subset)
attr_entropy += (len(value_subset) / len(examples)) * value_entropy
return total_entropy - attr_entropy
def majority_class(examples, target_attr):
return Counter([example[target_attr] for example in examples]).most_common(1)[0][0]
def learn_decision_tree(examples, attributes, target_attr):
if not examples:
return "No examples"
if all(example[target_attr] == examples[0][target_attr] for example in examples):
return examples[0][target_attr]
if not attributes:
return majority_class(examples, target_attr)
gains = {
attr: calculate_information_gain(examples, attr, target_attr)
for attr in attributes
}
best_attr = max(gains, key=gains.get)
tree = {best_attr: {}}
for value in set(example[best_attr] for example in examples):
subset = [example for example in examples if example[best_attr] == value]
new_attributes = attributes.copy()
new_attributes.remove(best_attr)
subtree = learn_decision_tree(subset, new_attributes, target_attr)
tree[best_attr][value] = subtree
return tree
def print_tree(tree):
outs = []
for key, value in sorted(tree.items()):
outs.append(f"{key}:{print_tree(value) if isinstance(value, dict) else value}")
return "{" + ",".join(outs) + "}"
if __name__ == "__main__":
examples = eval(input())
attributes = eval(input())
target_attr = eval(input())
print(print_tree(learn_decision_tree(examples, attributes, target_attr)))