C
题目描述
题意是求对于任意两个点,将树相连后这两个点最远的距离再求和。
做法
换根 DP。
先分类讨论一下:
- 对于同一个树的两个点,距离是固定的,我们需要算出每个点作为根节点到所有节点的距离的和,通过换根 实现,最后将所有点求出来的距离总和要除以 2 ,就是这部分答案。
- 对于不同的树上的两个点,可以发现,对于固定的 和 两点,他们一定是跑到离它最远的点再连边,所以对于每一个点都维护出向下的最长链和次长链,这个点在这颗树上的贡献都是固定的,在第一颗树的为 (+1) , 第二棵树的同理,(但只用在一边加1就行,因为要连边。)
这两部分均能在换根的时候维护。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
void solve() {
int n, m;
cin >> n >> m;
vector<vector<int>> e(n + m + 10);
for (int i = 1; i < n; ++i) {
int u, v;
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
for (int i = 1; i < m; ++i) {
int u, v;
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
vector<ll> siz(n + m + 10), mx(n + m + 10), d(n + m + 10), mxx(n + m + 10), p(n + m + 10);
//mx: 以 x 为根往下的最长链
//mxx: 以 x 为根往下的次长链
//d: 以 x 为根往下到所有子孙节点的距离和
function<void(int, int)> dfs = [&](int x, int ff) -> void {
siz[x] = 1;
for (int y : e[x]) {
if (y == ff) continue;
dfs(y, x);
siz[x] += siz[y];
if (mx[y] + 1 >= mx[x]) {
mxx[x] = mx[x];
mx[x] = mx[y] + 1;
p[x] = y;
} else mxx[x] = max(mxx[x], mx[y] + 1);
d[x] += d[y] + siz[y];
}
};
ll ans = 0, res = 0;
function<void(int, int, ll, ll, ll, int)> dfs1 = [&](int x, int ff, ll val, ll sizval, ll mch, int op) -> void {
res += d[x] + val;
ans += (max(mx[x], mch) + op) * (op ? m : n);
for (int y : e[x]) {
if (y == ff) continue;
ll t = d[x] - d[y] - siz[y];
ll s = siz[x] - siz[y];
dfs1(y, x, val + sizval + t + s, sizval + s, max(mch, (p[x] == y ? mxx[x] : mx[x])) + 1, op);
}
};
dfs(1, 1); dfs(n + 1, n + 1);
dfs1(1, 1, 0, 0, 0, 1); dfs1(n + 1, n + 1, 0, 0, 0, 0);
res /= 2; //记得除以 2,因为这部分会重复算两边。
cout << ans + res << "\n";
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int T = 1;
//cin >> T;
while (T--) solve();
return 0;
}