#include <bits/stdc++.h>
using namespace std;
// Capacity of the per-vertex arrays below (vertices are indexed from 1).
const int maxn = 1e5 + 5;
// fa[v]     : disjoint-set parent of v (a root satisfies fa[v] == v).
// cnt[v][c] : for a root v, how many vertices with label c (0 or 1) are in
//             its component; entries of non-root vertices are not kept up to date.
// w[v]      : per-vertex value read from the input.
int fa[maxn], cnt[maxn][2], w[maxn];
// The n - 1 edges read from the input, later sorted by max(w[u], w[v]).
vector<pair<int, int>> e;
// NOTE(review): ans and pans are not referenced anywhere in the visible code.
int ans = -1, pans = -1;
// n: number of vertices; k: threshold that a component's even-rounded label
// counts must reach (see the check in main).
int n, k;
// Returns the disjoint-set representative of i and compresses the walked path
// so that every vertex visited on the way points directly at that root.
int find(int i)
{
    // First pass: climb to the root.
    int root = i;
    while (fa[root] != root)
        root = fa[root];
    // Second pass: re-point each vertex on the path at the root.
    while (fa[i] != root)
    {
        const int parent = fa[i];
        fa[i] = root;
        i = parent;
    }
    return root;
}
// Roots of components that absorbed another component since the set was last
// cleared; merge() maintains it and main() scans and clears it after each step.
set<int> aa;
void merge(int x, int y)
{
int fx = find(x), fy = find(y);
if (fx != fy)
{
fa[fx] = fy;
cnt[fy][0] += cnt[fx][0];
cnt[fy][1] += cnt[fx][1];
aa.insert(fy);
if (aa.count(fx))
aa.erase(fx);
}
}
int main()
{
cin >> n >> k;
for (int i = 1; i <= n; i++)
fa[i] = i;
for (int i = 1; i <= n; i++)
{
int x;
cin >> x;
cnt[i][x] = 1;
}
for (int i = 1; i <= n; i++)
cin >> w[i];
e.resize(n - 1);
for (int i = 0; i < n - 1; i++)
{
cin >> e[i].first >> e[i].second;
}
sort(e.begin(), e.end(), [&](auto &a, auto &b)
{ return max(w[a.first], w[a.second]) < max(w[b.first], w[b.second]); });
// cout << endl;
// for (int i = 0; i < n - 1; i++)
// {
// cout << e[i].first << " " << e[i].second << endl;
// }
for (int i = 0; i < n - 1; i++)
{
merge(e[i].first, e[i].second);
if (i + 1 < n - 1 && max(w[e[i].first], w[e[i].second]) == max(w[e[i + 1].first], w[e[i + 1].second]))
{
i++;
merge(e[i].first, e[i].second);
}
for (auto j : aa)
{
if (cnt[j][0]/2*2+cnt[j][1]/2*2>=k)
{
cout << max(w[e[i].first], w[e[i].second]) << endl;
return 0;
}
}
aa.clear();
}
cout << -1 << endl;
}
// For reference only: this approach differs from the editorial solution.