Description
给出一棵树, 个询问,每次询问给出 ,需要得到有多少个点距离 相等。
Solution
分类讨论
- 时,每个点都可以,输出
- 为奇数时, 显然不存在点满足条件输出
剩余情况我们考虑结合 的特点
如果距离 相等的点在 上,如下图所示
从他们的 开始的其他点都满足条件,即答案为
其中 是 往上跳 的点, 是 往上跳 的点如果距离 相等的点不在 上,如下图所示
我们选择两者中深度较大的点(), 往上跳 到 ,往上跳 到 ,答案即为
于是归纳出答案,上述两种情况分别为
Code
#include<bits/stdc++.h> const int N = 5e5 + 5; typedef long long ll; std::vector<int> G[N]; int n, m; int sz[N], fa[N][20], dep[N]; void dfs(int u, int par) { sz[u] = 1; for(auto &v : G[u]) { if(v == par) continue; dfs(v, u); sz[u] += sz[v]; } } void bfs(int root) { std::queue<int> q; q.push(root); dep[root] = 0; fa[root][0] = root; while(!q.empty()) { auto tmp = q.front(); q.pop(); for(int i = 1; i < 20; i++) { fa[tmp][i] = fa[fa[tmp][i - 1]][i - 1]; } for(auto &v : G[tmp]) { if(v == fa[tmp][0]) continue; fa[v][0] = tmp; dep[v] = dep[tmp] + 1; q.push(v); } } } int LCA(int x, int y) { if(dep[x] > dep[y]) std::swap(x, y); int hu = dep[x], hv = dep[y]; int tu = x, tv = y; for(int det = hv - hu, i = 0; det; det >>= 1, i++) { if(det & 1) { tv = fa[tv][i]; } } if(tu == tv) return tu; for(int i = 19; i >= 0; --i) { if(fa[tu][i] == fa[tv][i]) continue; tu = fa[tu][i]; tv = fa[tv][i]; } return fa[tu][0]; } int cal(int x, int y) { return dep[x] + dep[y] - 2 * dep[LCA(x, y)]; } int get(int x, int det) { for(int i = 0; det; det >>= 1, i++) { if(det & 1) x = fa[x][i]; } return x; } int main() { std::ios::sync_with_stdio(false), std::cin.tie(nullptr), std::cout.tie(nullptr); std::cin >> n; for(int i = 1; i <= n - 1; i++) { int u, v; std::cin >> u >> v; G[u].push_back(v), G[v].push_back(u); } dfs(1, 0), bfs(1); std::cin >> m; while(m--) { int qx, qy; std::cin >> qx >> qy; int dist = cal(qx, qy); if(qx == qy) { std::cout << n << '\n'; } else if(dist & 1) { std::cout << "0\n"; } else { if(dep[qx] < dep[qy]) { std::swap(qx, qy); } int anc = LCA(qx, qy); int upx = get(qx, dist >> 1), upy = get(qy, dist >> 1); if(upx == anc) { int L = get(qx, dist / 2 - 1), R = get(qy, dist / 2 - 1); std::cout << n - sz[L] - sz[R] << '\n'; } else { int L = get(qx, dist / 2), R = get(qx, dist / 2 - 1); std::cout << sz[L] - sz[R] << '\n'; } } } return 0; }