C

题目描述

题意是求对于任意两个点,将树相连后这两个点最远的距离再求和。

做法

换根 DP。

先分类讨论一下:

  1. 对于同一个树的两个点,距离是固定的,我们需要算出每个点作为根节点到所有节点的距离的和,通过换根 实现,最后将所有点求出来的距离总和要除以 2 ,就是这部分答案。
  2. 对于不同的树上的两个点,可以发现,对于固定的 两点,他们一定是跑到离它最远的点再连边,所以对于每一个点都维护出向下的最长链和次长链,这个点在这颗树上的贡献都是固定的,在第一颗树的为 (+1) , 第二棵树的同理,(但只用在一边加1就行,因为要连边。)

这两部分均能在换根的时候维护。

#include<bits/stdc++.h>
using namespace std;
#define ll long long
void solve() {
    int n, m;
    cin >> n >> m;
    vector<vector<int>> e(n + m + 10);
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    for (int i = 1; i < m; ++i) {
        int u, v;
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    vector<ll> siz(n + m + 10), mx(n + m + 10), d(n + m + 10), mxx(n + m + 10), p(n + m + 10);
  //mx: 以 x 为根往下的最长链
  //mxx: 以 x 为根往下的次长链
  //d: 以 x 为根往下到所有子孙节点的距离和
    function<void(int, int)> dfs = [&](int x, int ff) -> void {
        siz[x] = 1;
        for (int y : e[x]) {
            if (y == ff) continue;
            dfs(y, x);
            siz[x] += siz[y];
            if (mx[y] + 1 >= mx[x]) {
                mxx[x] = mx[x];
                mx[x] = mx[y] + 1;
                p[x] = y;
            } else mxx[x] = max(mxx[x], mx[y] + 1);
            d[x] += d[y] + siz[y];
        }
    };
    ll ans = 0, res = 0;
    function<void(int, int, ll, ll, ll, int)> dfs1 = [&](int x, int ff, ll val, ll sizval, ll mch, int op) -> void {
        res += d[x] + val;
        ans += (max(mx[x], mch) + op) * (op ? m : n);
        for (int y : e[x]) {
            if (y == ff) continue;
            ll t = d[x] - d[y] - siz[y];
            ll s = siz[x] - siz[y];
            dfs1(y, x, val + sizval + t + s, sizval + s, max(mch, (p[x] == y ? mxx[x] : mx[x])) + 1, op);
        }
    };
    dfs(1, 1); dfs(n + 1, n + 1);
    dfs1(1, 1, 0, 0, 0, 1); dfs1(n + 1, n + 1, 0, 0, 0, 0);
    res /= 2; //记得除以 2,因为这部分会重复算两边。
    cout << ans + res << "\n";
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    int T = 1;
    //cin >> T;
    while (T--) solve();
    return 0;
}