原题链接:https://ac.nowcoder.com/acm/contest/9753/C
同
题意:
给出一个含有n个节点的无权无根树,请问有多少个点在树的直径上?(直径可能不只一条)
题解:
利用dfs和树形dp解决这个问题。
问题分解:
1.首先,如何判断一个节点是否在直径上?
这个节点到其他节点的距离中的最大值及次大值的和 == 树的直径
2.其次,如何找到一个节点到其他节点的最大值及次大值?
最大值及次大值只可能产生于:当前节点的向上搜索以及向下搜索的链。
即当前节点的最长子链、次长子链以及向上搜索最长链这三者之中。
选取三者中前2大的极为最大值以及次大值。
3.最后,如何求得最长子链、次长子链以及向上搜索最长链?
利用dfs以及树形dp。
设定
- dp[i][0]表示第i个节点的最长子链的长度。
- dp[i][1]表示第i个节点的次长子链的长度。
- dp[i][2]表示第i个节点的向上搜索最链的长度。
- myson[i]表示第i个节点最长子链的第一个儿子节点
利用两次dfs求得所有的dp值。
1.在第一次dfs中求出dp[i][0]、dp[i][1]:
状态转移方程: if(dp[son][0]+lenof(son,root)>dp[root][0])//更新最长子链以及次长子链 { dp[root][1]=dp[root][0]; myson[root]=son; dp[root][0]=dp[son][0]+lenof(son,root); } else if(dp[son][0]+lenof(son,root)>dp[root][1])//更新次长子链 { dp[root][1]=dp[son][0]+lenof(son,root); }
2.在第二次dfs中利用已求出的dp[i][0]、dp[i][1]去求解dp[i][2]:
状态转移方程: if(myson[root]==son) //若son在root的最长子链,那么向上搜索最长链可能在root的向上最长链或者root的次长子链上出现 dp[son][2]=max(dp[root][2],dp[root][1])+lenof(son,root); else //若son不在root的最长子链,那么向上搜索最长链可能在root的向上最长链或者root的最长子链上出现 dp[son][2]=max(dp[root][2],dp[root][1])+lenof(son,root);
AC代码:
#include <bits/stdc++.h> class Solution { public: /** * 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可 * * @param n int整型 节点个数 * @param u int整型vector * @param v int整型vector * @return int整型 */ vector<pair<int,int> > vec[100005]; int dp[100005][3],myson[100005]; bool vis[100005]; void dfs1(int root) { for(int i=0;i<(int)vec[root].size();++i) { int son = vec[root][i].first; int len = vec[root][i].second; if(!vis[son]) { vis[son]=true; dfs1(son); if(dp[son][0]+len>dp[root][0]) { dp[root][1]=dp[root][0]; myson[root]=son; dp[root][0]=dp[son][0]+len; } else if(dp[son][0]+len>dp[root][1]) { dp[root][1]=dp[son][0]+len; } vis[son]=false; } } } void dfs2(int root) { for(int i=0;i<(int)vec[root].size();++i) { int son = vec[root][i].first; int len = vec[root][i].second; if(!vis[son]) { if(myson[root]==son) dp[son][2]=max(dp[root][2],dp[root][1])+len; else dp[son][2]=max(dp[root][2],dp[root][0])+len; vis[son]=true; dfs2(son); vis[son]=false; } } } int PointsOnDiameter(int n, vector<int>& u, vector<int>& v) { // write code here int m = u.size(); for(int i=0;i<m;++i) { vec[u[i]].push_back(make_pair(v[i],1)); vec[v[i]].push_back(make_pair(u[i],1)); } for(int i=1;i<=n;++i){ vis[i]=false; dp[i][0]=dp[i][1]=dp[i][2]=0; myson[i]=0; } vis[1]=true; dfs1(1); dfs2(1); vis[1]=false; int mx = 0; for(int i=1;i<=n;++i) mx=max(mx,dp[i][0]+dp[i][2]); int ans = 0; for(int i=1;i<=n;++i) { if(dp[i][0]+dp[i][1]+dp[i][2]-min(dp[i][0],min(dp[i][1],dp[i][2]))==mx) ans++; } return ans; } };
欢迎指正和评论!