题目大意:
求两条树上边不相交路径, 使得删掉这两条路径上的点以后剩下的连通块数量最多。
做法
两条路径在树上大概会长成两个倒着的‘V’字形, 考虑在某一个'V'的最上面那个点统计答案。
于是统计答案的时候我们发现有几种统计法:
- 父节点有一个'V', 子节点有一个'V‘。
- 父节点上一条链到子节点, 子节点有一条链经过, 内部有一条路径, 合并两条链。
- 父节点被一条链经过, 子树内有一条路径, 子节点有一条链经过, 内部有一条路径。
发现当前点是否被经过了对于转移是有关系的, 所以设状态的时候不能够简单的设子树内选一条路径的方案数目。
设\(dp(x, 0)\)表示切除一条一个端点在\(x\)上的链的最大连通块数。
设\(dp(x, 1)\)表示切除一条\(x\)的子树内的不经过\(x\)的路径的最大连通块数。
设\(dp(x, 2)\)表示切除一条\(x\)的子树内过\(x\)的路径的最大连通块数。
设\(dp(x, 3)\)表示切除一条端点在\(x\)上的链和子树内一条路径的最大连通块数目。
以\(1\)为根做\(dp\)。
先考虑如何计算答案。
-
\(dp(x, 3) + dp(y, 0) - 1 + [x \ != \ 1]\)
-
\(dp(x, 0) + dp(y, 3) - 1 + [x\ != \ 1]\)
这里\(-1\)是因为\(x\)的一个儿子变成了链, 原先计算的连通块的点会被删掉。
然后再加上父节点形成的连通块来计算。
- \(dp(x, 1) + dp(y, 2)\)
- \(dp(x, 2) + dp(y, 1) - 1 + [x\ != \ 1]\)
- \(dp(x, 2) +dp(y, 2) - 1 + [x\ != \ 1]\)
这里要减去\(1\)是因为\(x\)被断开以后, 会把\(x\)的所有子节点都计算成连通块, 会重复算一个。
- $ dp(x, 1) + dp(y, 1) - 1 $
这个也是去重。
然后考虑\(dp\)数组。
$dp(x, 0) = max { dp(x, 0) ,dp(y, 0) + deg(x) - 1 } $。
很显然吧。。。
\(dp(x, 1) = max \{dp(x, 1), dp(y, 1), dp(y, 2) + 1\}\)
\(dp(x, 2) = max\{dp(x, 2), dp(x, 0) +dp(y, 0) - 1\}\)
还有一个转移就是从当前儿子里选一条链, 和前面的路径拼起来。
代码
#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
#define R register
const int N = 1e5 + 10;
inline int read() {
int x = 0, f = 1; char a = getchar();
for(; a > '9' || a < '0'; a = getchar()) if(a == '-') f = -1;
for(; a >= '0' && a <= '9'; a = getchar()) x = x * 10 + a - '0';
return x * f;
}
struct edge {
int to, next;
}e[N << 1];
int cnt, head[N];
inline void add(int x, int y) {
e[++ cnt] = {y, head[x]}; head[x] = cnt;
}
inline void Max(int &x, int y) { x = x < y ? y : x; }
int n, type, p0, p1, h0, h1;
int deg[N], Ans;
int dp[N][4];
inline void dfs(int x, int fx) {
dp[x][0] = dp[x][2] = dp[x][3] = deg[x];
dp[x][1] = 1;
int res = 0;
for(R int i = head[x]; i; i = e[i].next) {
int y = e[i].to;
if(y == fx) continue;
dfs(y, x);
Max(Ans, dp[x][3] + dp[y][0] - (x == 1));
Max(Ans, dp[x][0] + dp[y][3] - (x == 1));
Max(Ans, dp[x][1] + dp[y][2]);
Max(Ans, dp[x][2] + dp[y][1] - 1);
Max(Ans, dp[x][1] + dp[y][1] - 1);
Max(Ans, dp[x][2] + dp[y][1] - (x == 1));
Max(Ans, dp[x][2] + dp[y][2] - (x == 1));
Max(dp[x][1], dp[y][1]);
Max(dp[x][1], dp[y][2] + 1);
Max(dp[x][3], dp[x][0] + dp[y][2] - 1);
Max(dp[x][3], dp[x][0] + dp[y][1] - 1);
Max(dp[x][3], dp[x][2] + dp[y][0] - 1);
Max(dp[x][3], dp[y][3] + deg[x] - 1);
Max(dp[x][3], dp[y][0] + deg[x] + res - 2);
Max(dp[x][2], dp[x][0] + dp[y][0] - 1);
Max(dp[x][0], dp[y][0] + deg[x] - 1);
Max(dp[x][2], dp[x][0]);
Max(dp[x][3], dp[x][2]);
Max(res, dp[y][1]);
Max(res, dp[y][2]);
}
}
inline void solve() {
n = read();
if(type == 1) p0 = read(), p1 = read();
if(type == 2) p0 = read(), p1 = read(), h0 = read(), h1 = read();
cnt = 0; Ans = 0;
for(R int i = 1; i <= n; i ++) head[i] = 0, deg[i] = 0;
for(R int i = 1; i < n; i ++) {
int x = read(), y = read();
add(x, y); add(y, x);
deg[x] ++; deg[y] ++;
}
for(R int i = 2; i <= n; i ++) deg[i] --;
dfs(1, 0);
printf("%d\n", Ans);
}
int main() {
#ifdef IN
//freopen(".in", "r", stdin);
//freopen(".out", "w", stdout);
#endif
int T = read();
type = read();
while(T --) solve();
return 0;
}