考察知识点:树上 DFS、倍增、LCA
设 为颜色
的两个直径端点(即距离最远的两个点),
为颜色
中的一个点,则点
到颜色
中点的最大距离为
。
可以想象一个点到一条线段的最远距离必然出现在这个点到两个端点的距离中,在此不做严格证明。
进而我们有如下结论:
设 为颜色
的两个直径端点,
为颜色
的两个直径端点,则颜色
中的点到颜色
中点的最大距离为
。
特殊的,当某种颜色的点不足两个时,结论可以简化。
因此我们的任务就变成了查找每种颜色距离最远的两个点。
我们使用 LCA (最近公共祖先)来计算树中每两点间的距离:
其中, 表示点
的深度(即点
到根节点的距离)。
对于代码中一些变量的解释如下:
edges[i]
:点的所有连边
fa[i][j]
:点向上
步的点(即点
的第
代祖先)
dep[i]
:点的深度(到根节点的距离)
lg[i]
:(
向下取整)
mp[x]
颜色为的所有点
longest[x]
颜色的直径端点(颜色
中距离最远的两点)
时间复杂度:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef vector<int> vi;
typedef vector<ll> vl;
typedef vector<pii> vpii;
typedef vector<pll> vpll;
const int N = 1e5 + 5, LOGN = 20;
vector<int> edges[N];
int c[N], fa[N][LOGN];
int dep[N], lg[N];
map<int, vi> mp;
map<int, pii> longest;
void dfs(int u, int from)
{
mp[c[u]].push_back(u); // 记录颜色为 c[u] 的点
dep[u] = dep[from] + 1;
fa[u][0] = from; // 记录父节点
for (int i = 1; 1 << i <= dep[u]; i++)
fa[u][i] = fa[fa[u][i - 1]][i - 1]; // 记录 2^i 祖先
for (int v : edges[u]) // 遍历所有子节点
if (v != from)
dfs(v, u);
}
int lca(int x, int y) // 求 x 和 y 的最近公共祖先
{
if (dep[x] < dep[y]) // 保证 x 深度大于 y
swap(x, y);
while (dep[x] > dep[y]) // 先将 x 跳到和 y 一样深
x = fa[x][lg[dep[x] - dep[y]] - 1];
if (x == y) // 如果 x 和 y 相等,说明 y 是 x 的祖先
return x;
for (int i = lg[dep[x]] - 1; i >= 0; i--) // 从 x 开始向上跳
if (fa[x][i] != fa[y][i]) // 如果 x 和 y 的 2^i 祖先不同
x = fa[x][i], y = fa[y][i]; // 同时跳
return fa[x][0];
}
int dis(int x, int y) // 求 x 和 y 的距离
{
return dep[x] + dep[y] - 2 * dep[lca(x, y)];
}
void solve()
{
int n, q;
cin >> n >> q;
for (int i = 1; i < N; i++)
lg[i] = lg[i - 1] + (1 << lg[i - 1] == i); // 预处理 log2
for (int i = 1; i <= n; i++)
cin >> c[i];
for (int i = 0; i < n - 1; i++)
{
int u, v;
cin >> u >> v;
edges[u].push_back(v);
edges[v].push_back(u);
}
dfs(1, 0);
for (auto item : mp) // 预处理每种颜色距离最远的两个点
{
int color = item.first;
vi &nodes = item.second;
longest[color] = {nodes[0], nodes[0]};
for (int i = 1; i < nodes.size(); i++)
{
int a = dis(longest[color].first, nodes[i]);
int b = dis(longest[color].second, nodes[i]);
if (a < b)
{
swap(a, b);
swap(longest[color].first, longest[color].second);
}
if (a > dis(longest[color].first, longest[color].second))
swap(longest[color].second, nodes[i]);
}
}
while (q--)
{
int x, y, ans = 0;
cin >> x >> y;
if (mp.count(x) && mp.count(y))
{
int x1 = longest[x].first, x2 = longest[x].second;
int y1 = longest[y].first, y2 = longest[y].second;
ans = max({dis(x1, y1), dis(x1, y2), dis(x2, y1), dis(x2, y2)});
}
cout << ans << endl;
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int t = 1;
// cin >> t;
while (t--)
solve();
return 0;
}