稍微复杂的换根 DP,我能一发 A 掉的还是不多的...
题目大意
给出一棵有 个结点的树,其中 个结点 作特殊标记,令 代表结点 到 简单路径上边数,求有多少个点 ,满足
题解
不妨令 为根。
考虑结点 的最远标记点,可以在 的子树 内,也可以在 的子树外。
对这两种情况分别讨论。
在 的子树内情况比较简单,可以通过一次 dfs 求出。
记 为 内最远标记点到 的距离,如果 内没有标记点,则
但在这个 dfs 过程中,还要维护 内第二远标记点到 的距离 ,原因将在 解释。
在 的子树外情况比较复杂,但可以自然地想到是一个换根 DP:根 号点不存在子树外情况,已经可以求解。
假设 是 的任一儿子,此时对 有两种情况:
- ,即 由 转移而来
- ,即 不是由 转移而来。
以题目样例为例
其中 号点只有一个儿子 号点,那么 号点符合上述的第一种情况,也就是说, 号点的最远点就在孩子 号子树内。
此时,如果从 号点向 号点换根,如果使用 ,则 号点子树外的情况并未得到考虑,所以还需要记录 中所述的 ,并在换根过程中分类讨论。
其他细节请见代码。
#include<bits/stdc++.h> using namespace std; typedef long long LL; template < typename Tp > void read(Tp &x) { x = 0; int fh = 1; char ch = 1; while(ch != '-' && (ch < '0' || ch > '9')) ch = getchar(); if(ch == '-') fh = -1, ch = getchar(); while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar(); x *= fh; } const int maxn = 100000 + 7; const int maxm = 200000 + 7; const int INF = 0x3f3f3f3f; int n, m, d, dp[maxn], dp2[maxn]; bool mark[maxn]; int Head[maxn], Next[maxm], to[maxm], tot; void addedge(int x, int y) { to[++tot] = y, Next[tot] = Head[x], Head[x] = tot; } void add(int x, int y){ addedge(x, y); addedge(y, x); } void Init(void) { read(n); read(m); read(d); for(int i = 1, x; i <= m; i++) { read(x); mark[x] = true; } for(int i = 1, x, y; i < n; i++) { read(x); read(y); add(x, y); } } void dfs(int x, int fa) { dp[x] = dp2[x] = INF; for(int i = Head[x] ;i ; i = Next[i]) { int y = to[i]; if(y == fa) continue; dfs(y, x); if(dp[y] != INF) { if(dp[x] == INF) dp[x] = dp[y] + 1; else { if(dp[y] + 1 > dp[x]) { dp2[x] = dp[x]; dp[x] = dp[y] + 1; } else if(dp2[x] == INF) dp2[x] = dp[y] + 1; else dp2[x] = max(dp2[x], dp[y] + 1); } } } if(dp[x] == INF) if(mark[x]) dp[x] = 0; } int ans; void dfs2(int x, int fa, int mov) { if(mov == INF && mark[fa]) mov = 1; // printf("id = %d, fa = %d, mov = %d, dp[x] = %d\n", x, fa, mov, dp[x]); if((dp[x] <= d || dp[x] == INF) && (mov <= d || mov == INF)) { ++ans; } for(int i = Head[x]; i; i = Next[i]) { int y = to[i]; if(y == fa) continue; if(dp[x] == dp[y] + 1) { if(mov == INF && dp2[x] != INF) dfs2(y, x, dp2[x] + 1); else if(mov != INF && dp2[x] == INF) dfs2(y, x, mov + 1); else if(mov == INF && dp2[x] == INF) dfs2(y, x, INF); else dfs2(y, x, max(dp2[x], mov) + 1); } else { if(mov == INF && dp[x] != INF) dfs2(y, x, dp[x] + 1); else if(mov != INF && dp[x] == INF) dfs2(y, x, mov + 1); else if(mov == INF && dp[x] == INF) dfs2(y, x, INF); else dfs2(y, x, max(dp[x], mov) + 1); } } } void Work(void) { dfs(1, 0); // for(int i = 1; i <= n; i++) { // printf("%d : %d\n", i, dp[i]); // } dfs2(1, 0, INF); printf("%d\n", ans); } int main(void) { Init(); Work(); return 0; }