题目链接:https://ac.nowcoder.com/discuss/434394
#include <bits/stdc++.h>
#define LL long long
using namespace std;
vector<int> v[500005];
int f[500005][2];
void DFS(int u, int fa){
f[u][1]=1;
for(auto x: v[u]){
if(x!=fa){
DFS(x, u);
f[u][1]+=f[x][0];
f[u][0]+=max(f[x][0], f[x][1]);
}
}
}
int main(){
int n, s, x, y; scanf("%d%d", &n, &s);
for(int i=1; i<n; i++){
scanf("%d%d", &x, &y);
v[x].push_back(y);
v[y].push_back(x);
}
DFS(s, 0);
printf("%d\n", f[s][1]);
return 0;
}

京公网安备 11010502036488号