# Approach: a single depth-first search (DFS) over the rooted tree suffices.

from collections import defaultdict
# First line: number of nodes N and the maximum allowed subtree size K.
N, K = (int(tok) for tok in input().split())
# Second line: per-node values; node v's value is stored at tree[v - 1].
tree = [int(tok) for tok in input().split()]
# Next N - 1 lines: undirected edges, recorded in both directions.
adj = defaultdict(list)
for _ in range(N - 1):
    u, v = (int(tok) for tok in input().split())
    adj[u].append(v)
    adj[v].append(u)

def make_tree(i):
    for j in adj[i]:
        if i in adj[j]:
            adj[j].remove(i)
        make_tree(j)

# The root is given on its own line; re-orient adj so each list holds children only.
root = int(input())
make_tree(root)

# Running best answer, updated by dfs:
#   res       - chosen subtree root, or -1 if no subtree has at most K nodes
#   diff      - max - min spread of that subtree (starts below any real spread)
#   best_node - node id used for tie-breaking toward smaller ids
res, diff, best_node = -1, -float('inf'), N
def dfs(node):
    """Scan the subtree rooted at ``node`` bottom-up and record the best answer.

    For every node v in the subtree, this computes the min and max of the
    values ``tree[v - 1]`` over v's subtree, and the subtree's node count.
    Each subtree with at most ``K`` nodes is compared with the current best.
    A larger spread (max - min) wins, and an equal spread wins if its root id
    is smaller. The winner is stored in the globals ``res`` and
    ``best_node``, and its spread in ``diff``.

    The traversal is iterative. The previous recursive version raised
    RecursionError on trees deeper than Python's recursion limit.

    Reads the module globals ``adj`` (expected to hold child lists only, as
    produced by make_tree), ``tree`` and ``K``.

    Returns:
        (min_value, max_value, node_count) for the subtree rooted at ``node``.
    """
    global diff
    global res
    global best_node

    # Pre-order: every parent is appended before its children, so walking
    # this list backwards visits children before their parent (post-order).
    order = []
    stack = [node]
    while stack:
        cur = stack.pop()
        order.append(cur)
        stack.extend(adj[cur])

    stats = {}  # finished subtree -> (min, max, count)
    for cur in reversed(order):
        min_val = max_val = tree[cur - 1]
        count = 1
        for child in adj[cur]:
            # pop: a child's summary is only needed once, by its parent.
            tmin, tmax, tcount = stats.pop(child)
            min_val = min(tmin, min_val)
            max_val = max(tmax, max_val)
            count += tcount
        stats[cur] = (min_val, max_val, count)

        if count <= K:
            cur_diff = max_val - min_val
            # The tie-break keeps the smallest node id among equal spreads,
            # so the final result does not depend on the visiting order.
            if cur_diff > diff or (cur_diff == diff and best_node > cur):
                res = cur
                best_node = cur
                diff = cur_diff

    return stats[node]

# Evaluate every subtree once, starting from the root. dfs writes the
# chosen node into the global `res`; its return value is not needed here.
dfs(root)
print(res)  # -1 when no subtree fits within K nodes