牛客练习赛64-B
题目描述:
给出一颗n个点n−1条边的树,点的编号为1,2,...,n−1,n,对于每个点i(1<=i<=n),输出与点i距离为2的点的个数。
两个点的距离定义为两个点最短路径上的边的条数。
输入描述:
第一行一个正整数n。
接下来n−1行每行两个正整数u,v表示点u,v之间有一条边。
输出描述:
输入共n行,第i行输出一个整数表示与点i距离为2的点的个数
输入
4
1 2
2 3
3 4
输出
1
1
1
1
思路
简单的树形dp
dp[u][1]表示以u为根的子树与u距离为1的点的数量,也就是u的儿子的数量
dp[u][1]表示以u为根的子树与u距离为2的点的数量,也就是u的孙子的数量
转移方程很好想 每次搜索u-v时 dp[u][2]+=dp[v][1]; dp[u][1]++;
对于u除了孩子这些距离为2,还有可能是fa[u]的孩子数量-1,还有可能就是fa[fa[u]]->u的距离也为2
综上
ans[i]=dp[i][2];
如果存在父亲 ans[i]+=dp[fa[i]][1]-1;
如果存在爷爷 ans[i]++;
最后输出即可
#pragma GCC optimize(3,"Ofast","inline") //G++ #include<bits/stdc++.h> #define mem(a,x) memset(a,x,sizeof(a)) #define debug(x) cout << #x << ": " << x << endl; #define ios ios::sync_with_stdio(false);cin.tie(0);cout.tie(0); #define fcout cout<<setprecision(4)<<fixed using namespace std; typedef long long ll; //====================================== namespace FastIO{ char print_f[105];void read() {}void print() {putchar('\n');} template <typename T, typename... T2> inline void read(T &x, T2 &... oth){x = 0;char ch = getchar();ll f = 1;while (!isdigit(ch)){if (ch == '-')f *= -1;ch = getchar();}while (isdigit(ch)){x = x * 10 + ch - 48;ch = getchar();}x *= f;read(oth...);} template <typename T, typename... T2> inline void print(T x, T2... oth){ll p3=-1;if(x<0) putchar('-'),x=-x;do{print_f[++p3] = x%10 + 48;}while(x/=10);while(p3>=0) putchar(print_f[p3--]);putchar(' ');print(oth...);}} // namespace FastIO using FastIO::print; using FastIO::read; //====================================== typedef pair<int,int> pii; const int inf=0x3f3f3f3f; const int mod=1e9+7; const int maxn = 1e6+5; int dp[maxn][3]; int fa[maxn]; vector<int>edge[maxn]; void add(int x,int y){ edge[x].push_back(y); edge[y].push_back(x); } void dfs(int u,int fr){ fa[u]=fr; for(auto v:edge[u]){ if(v==fr) continue; dfs(v,u); dp[u][2]+=dp[v][1]; dp[u][1]++; } } int ans[maxn]; int main() { #ifndef ONLINE_JUDGE // freopen("H:\\code\\in.in", "r", stdin); // freopen("H:\\code\\out.out", "w", stdout); clock_t c1 = clock(); #endif //************************************** int n; read(n); for(int i=1;i<n;i++){ int x,y; read(x,y); add(x,y); } dfs(1,0); for(int i=1;i<=n;i++){ ans[i]=dp[i][2]; if(fa[i]) ans[i]+=dp[fa[i]][1]-1; if(fa[fa[i]]) ans[i]++; } for(int i=1;i<=n;i++){ print(ans[i]); } //************************************** #ifndef ONLINE_JUDGE cerr << "Time:" << clock() - c1 << "ms" << endl; #endif return 0; }