# Approach: a single DFS over the rooted tree is enough (DFS即可).
from collections import defaultdict
# First line: node count N and the subtree-size limit K.
N, K = (int(tok) for tok in input().split())
# Second line: one value per node; node v's value is tree[v - 1].
tree = [int(tok) for tok in input().split()]
# Remaining N-1 lines: undirected edges, recorded in both directions.
adj = defaultdict(list)
for _edge in range(N - 1):
    u, v = map(int, input().split())
    adj[u].append(v)
    adj[v].append(u)
def make_tree(i, graph=None):
    """Orient undirected adjacency lists away from root ``i``, in place.

    For every node reached from ``i``, the node is removed from each
    neighbour's list before that neighbour is visited. Assuming the graph is
    a tree, afterwards ``graph[v]`` holds only the children of ``v``.

    An explicit stack replaces the original recursion. A path-shaped tree
    with more than ~1000 nodes would otherwise exceed Python's default
    recursion limit and raise ``RecursionError``.

    Args:
        i: The root node.
        graph: Adjacency mapping (node -> list of neighbours) to orient.
            Defaults to the module-level ``adj``.
    """
    if graph is None:
        graph = adj
    stack = [i]
    while stack:
        parent = stack.pop()
        for child in graph[parent]:
            # Only descend through an edge that still points back at the
            # parent; removing it first is what makes the lists directed.
            if parent in graph[child]:
                graph[child].remove(parent)
                stack.append(child)
# Read the root node, then strip parent links out of the adjacency lists.
root = int(input())
make_tree(root)

# Best answer tracked by dfs(): chosen node, its max-min spread, and the
# node id used for tie-breaking (a smaller id wins on equal spread).
res = -1
diff = float('-inf')
best_node = N
def dfs(node):
    """Summarize the subtree rooted at ``node`` and update the best answer.

    Every node in the subtree (``adj`` must already list only children) is
    processed in post-order. For each subtree with at most ``K`` nodes, the
    spread ``max - min`` of its values (``tree[v - 1]`` for node ``v``) is
    compared with the global best. It replaces the best when the spread is
    larger, or equal with a smaller node id. On replacement, the globals
    ``res``, ``best_node`` and ``diff`` are updated.

    The traversal uses an explicit stack rather than recursion, so a tree
    deeper than Python's recursion limit (~1000 by default) no longer raises
    ``RecursionError``.

    Args:
        node: Root of the subtree to summarize.

    Returns:
        Tuple ``(min_val, max_val, count)`` for the subtree rooted at ``node``.
    """
    global diff
    global res
    global best_node
    summary = {}  # finished node -> (min_val, max_val, count) of its subtree
    stack = [(node, False)]
    while stack:
        cur, expanded = stack.pop()
        if not expanded:
            # Revisit cur after all its children are finished. Children are
            # pushed reversed so they are completed in adj order, as before.
            stack.append((cur, True))
            stack.extend((child, False) for child in reversed(adj[cur]))
            continue
        min_val = max_val = tree[cur - 1]
        count = 1
        for child in adj[cur]:
            # pop: each child summary is consumed exactly once by its parent.
            tmin, tmax, tcount = summary.pop(child)
            min_val = min(tmin, min_val)
            max_val = max(tmax, max_val)
            count += tcount
        if count <= K:
            cur_diff = max_val - min_val
            if cur_diff > diff or (cur_diff == diff and best_node > cur):
                res = cur
                best_node = cur
                diff = cur_diff
        summary[cur] = (min_val, max_val, count)
    return summary[node]
# Evaluate every subtree from the root, then report the winning node.
dfs(root)
print(f"{res}")

# 京公网安备 11010502036488号 — appears to be a web-page footer pasted along with the solution; not part of the program.