稍微复杂的换根 DP,我能一发 A 掉的还是不多的...

题目大意

给出一棵有 个结点的树,其中 个结点 作特殊标记,令 代表结点 简单路径上边数,求有多少个点 ,满足


题解

不妨令 为根。

考虑结点 的最远标记点,可以在 的子树 内,也可以在 的子树外。

对这两种情况分别讨论。

的子树内情况比较简单,可以通过一次 dfs 求出。

内最远标记点到 的距离,如果 内没有标记点,则

但在这个 dfs 过程中,还要维护 内第二远标记点到 的距离 ,原因将在 解释。

的子树外情况比较复杂,但可以自然地想到是一个换根 DP:根 号点不存在子树外情况,已经可以求解。

假设 的任一儿子,此时对 有两种情况:

  • ,即 转移而来
  • ,即 不是由 转移而来。

以题目样例为例

其中 号点只有一个儿子 号点,那么 号点符合上述的第一种情况,也就是说, 号点的最远点就在孩子 号子树内。

此时,如果从 号点向 号点换根,如果使用 ,则 号点子树外的情况并未得到考虑,所以还需要记录 中所述的 ,并在换根过程中分类讨论。

其他细节请见代码。


#include<bits/stdc++.h>
using namespace std;

typedef long long LL;

template < typename Tp >
void read(Tp &x) {
    x = 0; int fh = 1; char ch = 1;
    while(ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if(ch == '-') fh = -1, ch = getchar();
    while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
    x *= fh;
}

const int maxn = 100000 + 7;
const int maxm = 200000 + 7;
const int INF = 0x3f3f3f3f;

int n, m, d, dp[maxn], dp2[maxn];
bool mark[maxn];

int Head[maxn], Next[maxm], to[maxm], tot;
void addedge(int x, int y) {
    to[++tot] = y, Next[tot] = Head[x], Head[x] = tot;
}
void add(int x, int y){
    addedge(x, y); addedge(y, x);
}

void Init(void) {
    read(n); read(m); read(d);
    for(int i = 1, x; i <= m; i++) {
        read(x); mark[x] = true;
    }
    for(int i = 1, x, y; i < n; i++) {
        read(x); read(y);
        add(x, y);
    }
}

void dfs(int x, int fa) {
    dp[x] = dp2[x] = INF;
    for(int i = Head[x] ;i ; i = Next[i]) {
        int y = to[i];
        if(y == fa) continue;
        dfs(y, x);
        if(dp[y] != INF) {
            if(dp[x] == INF) dp[x] = dp[y] + 1;
            else {
                if(dp[y] + 1 > dp[x]) {
                    dp2[x] = dp[x];
                    dp[x] = dp[y] + 1;
                }
                else if(dp2[x] == INF) dp2[x] = dp[y] + 1;
                else dp2[x] = max(dp2[x], dp[y] + 1);
            }
        }
    }
    if(dp[x] == INF) if(mark[x]) dp[x] = 0;
}

int ans;

void dfs2(int x, int fa, int mov) {
    if(mov == INF && mark[fa]) mov = 1;
//    printf("id = %d, fa = %d, mov = %d, dp[x] = %d\n", x, fa, mov, dp[x]);
    if((dp[x] <= d || dp[x] == INF) && (mov <= d || mov == INF)) {
        ++ans;
    }
    for(int i = Head[x]; i; i = Next[i]) {
        int y = to[i];
        if(y == fa) continue;
        if(dp[x] == dp[y] + 1) {
            if(mov == INF && dp2[x] != INF) dfs2(y, x, dp2[x] + 1);
            else if(mov != INF && dp2[x] == INF) dfs2(y, x, mov + 1);
            else if(mov == INF && dp2[x] == INF) dfs2(y, x, INF);
            else dfs2(y, x, max(dp2[x], mov) + 1);
        }
        else {
            if(mov == INF && dp[x] != INF) dfs2(y, x, dp[x] + 1);
            else if(mov != INF && dp[x] == INF) dfs2(y, x, mov + 1);
            else if(mov == INF && dp[x] == INF) dfs2(y, x, INF);
            else dfs2(y, x, max(dp[x], mov) + 1);
        }
    }
}

void Work(void) {
    dfs(1, 0);
//    for(int i = 1; i <= n; i++) {
//        printf("%d : %d\n", i, dp[i]);
//    }
    dfs2(1, 0, INF);
    printf("%d\n", ans);
}

int main(void) {
    Init();
    Work();
    return 0;
}