题目大意:

求两条树上边不相交路径, 使得删掉这两条路径上的点以后剩下的连通块数量最多。

做法

两条路径在树上大概会长成两个倒着的‘V’字形, 考虑在某一个'V'的最上面那个点统计答案。

于是统计答案的时候我们发现有几种统计法:

  • 父节点有一个'V', 子节点有一个'V‘。
  • 父节点上一条链到子节点, 子节点有一条链经过, 内部有一条路径, 合并两条链。
  • 父节点被一条链经过, 子树内有一条路径, 子节点有一条链经过, 内部有一条路径。

发现当前点是否被经过了对于转移是有关系的, 所以设状态的时候不能够简单的设子树内选一条路径的方案数目。

\(dp(x, 0)\)表示切除一条一个端点在\(x\)上的链的最大连通块数。

\(dp(x, 1)\)表示切除一条\(x\)的子树内的不经过\(x\)的路径的最大连通块数。

\(dp(x, 2)\)表示切除一条\(x\)的子树内过\(x\)的路径的最大连通块数。

\(dp(x, 3)\)表示切除一条端点在\(x\)上的链和子树内一条路径的最大连通块数目。

\(1\)为根做\(dp\)

先考虑如何计算答案。

  • \(dp(x, 3) + dp(y, 0) - 1 + [x \ != \ 1]\)

  • \(dp(x, 0) + dp(y, 3) - 1 + [x\ != \ 1]\)

这里\(-1\)是因为\(x\)的一个儿子变成了链, 原先计算的连通块的点会被删掉。

然后再加上父节点形成的连通块来计算。

  • \(dp(x, 1) + dp(y, 2)\)
  • \(dp(x, 2) + dp(y, 1) - 1 + [x\ != \ 1]\)
  • \(dp(x, 2) +dp(y, 2) - 1 + [x\ != \ 1]\)

这里要减去\(1\)是因为\(x\)被断开以后, 会把\(x\)的所有子节点都计算成连通块, 会重复算一个。

  • $ dp(x, 1) + dp(y, 1) - 1 $

这个也是去重。

然后考虑\(dp\)数组。

$dp(x, 0) = max { dp(x, 0) ,dp(y, 0) + deg(x) - 1 } $。

很显然吧。。。

\(dp(x, 1) = max \{dp(x, 1), dp(y, 1), dp(y, 2) + 1\}\)

\(dp(x, 2) = max\{dp(x, 2), dp(x, 0) +dp(y, 0) - 1\}\)

\[dp(x,3) = max\{dp(x, 3), \\dp(x, 0) + dp(y, 2) - 1, \\dp(x, 0) + dp(y, 1) - 1,\\ dp(x, 2) +dp(y, 0)- 1 ,\\ dp(y, 3) +deg(x) - 1\} \]

还有一个转移就是从当前儿子里选一条链, 和前面的路径拼起来。

代码

#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>

using namespace std;

#define R register
const int N = 1e5 + 10;

inline int read() {
	int x = 0, f = 1; char a = getchar();
	for(; a > '9' || a < '0'; a = getchar()) if(a == '-') f = -1;
	for(; a >= '0' && a <= '9'; a = getchar()) x = x * 10 + a - '0';
	return x * f;
}

struct edge {
	int to, next;
}e[N << 1];
int cnt, head[N];
inline void add(int x, int y) { 
	e[++ cnt] = {y, head[x]}; head[x] = cnt; 
}

inline void Max(int &x, int y) { x = x < y ? y : x; }

int n, type, p0, p1, h0, h1;
int deg[N], Ans;

int dp[N][4];

inline void dfs(int x, int fx) {
	dp[x][0] = dp[x][2] = dp[x][3] = deg[x];
	dp[x][1] = 1;
	int res = 0;
	for(R int i = head[x]; i; i = e[i].next) {
		int y = e[i].to;
		if(y == fx) continue;
		dfs(y, x);
		Max(Ans, dp[x][3] + dp[y][0] - (x == 1));
		Max(Ans, dp[x][0] + dp[y][3] - (x == 1));
		Max(Ans, dp[x][1] + dp[y][2]);
		Max(Ans, dp[x][2] + dp[y][1] - 1);
		Max(Ans, dp[x][1] + dp[y][1] - 1);
		Max(Ans, dp[x][2] + dp[y][1] - (x == 1));
		Max(Ans, dp[x][2] + dp[y][2] - (x == 1));
		Max(dp[x][1], dp[y][1]);
		Max(dp[x][1], dp[y][2] + 1);
		Max(dp[x][3], dp[x][0] + dp[y][2] - 1);
		Max(dp[x][3], dp[x][0] + dp[y][1] - 1);
		Max(dp[x][3], dp[x][2] + dp[y][0] - 1);
		Max(dp[x][3], dp[y][3] + deg[x] - 1);
		Max(dp[x][3], dp[y][0] + deg[x] + res - 2);
		Max(dp[x][2], dp[x][0] + dp[y][0] - 1);
		Max(dp[x][0], dp[y][0] + deg[x] - 1);
		Max(dp[x][2], dp[x][0]);
		Max(dp[x][3], dp[x][2]);
		Max(res, dp[y][1]);
		Max(res, dp[y][2]);
	}
}

inline void solve() {
	n = read(); 
	if(type == 1) p0 = read(), p1 = read();
	if(type == 2) p0 = read(), p1 = read(), h0 = read(), h1 = read(); 
	cnt = 0; Ans = 0;
	for(R int i = 1; i <= n; i ++) head[i] = 0, deg[i] = 0; 
	for(R int i = 1; i < n; i ++) {
		int x = read(), y = read(); 
		add(x, y); add(y, x);
		deg[x] ++; deg[y] ++;
	}
	for(R int i = 2; i <= n; i ++) deg[i] --;
	dfs(1, 0);
	printf("%d\n", Ans);
}

int main() {
	#ifdef IN
	//freopen(".in", "r", stdin);
	//freopen(".out", "w", stdout);
	#endif
	int T = read();
	type = read();
	while(T --) solve();
	return 0;
}