题目大意
有一个n个点的树,走过每一条边时会减少一定的hp,每条边只能最多经过2次(即每个子树只能进出一次)。
吃了第i个点的苹果会回复a[i]的hp。可以在原地休息,每个单位时间回复1hp。hp不能为负。
从根节点1开始出发,初始hp为0,要求经过所有点后返回1号点,求最少休息时间。
解题思路
这道题显然是用贪心的树形dp来求解。我们可以把边和边的权值都存起来,然后dfs搜索。为了保证能回溯,把找到过的节点存入vector。
接下来考虑一下我们的贪心策略,那么一拍脑袋:我们应该优先走加hp更多,时间t+hp更大的节点。
用官方题解来充实我糟糕的表达.exe
- 显而易见,我们最方便的操作就是在根节点恢复足够的hp再出发。所以我们要求的答案,就是最少的可以无缝访问所有节点的hp。
- 我们设访问子树i需要的总hp为, 访问过程中最坏情况需要HP为。
由于访问的子树所需要的总hp-是一个定值,故我们只考虑子树访问顺序对于的影响。
我们考虑排布子树的访问顺序,得到:
- 考虑对儿子访问排序,访问顺序应满足相邻交换最优原则。设有相邻对,若不交换更优,则有。
- 我们将min,max转换为简单的逻辑表达式:
若且肯定成立,所以我们先对于是否小于0讨论,将儿子的集合分成两部分A,B。其中𝐴为所有满足的所有形成的集合,B为所有满足的所有形成的集合,这样把 𝐴 中的某个元素排在 𝐵 中的另一个元素之前一定满足相邻交换原则。 - 然后,我们对于A内的顺序考虑,一定有比较式后项成立,只需满足。我们再对于B内的顺序考虑,一定有比较式前项成立,只需满足。
AC代码
pair版本
#include<bits/stdc++.h> using namespace std; const int N=1e5+10; pair<long,long> dp[N]; //first为最少时间t,second为当前hp int a[N]; vector<pair<long,long> > v[N]; bool cmp(pair<long,long> a,pair<long,long> b) { if(a.second-a.first>=0) { if(b.second-b.first<0) return 1; return a.first<b.first; } else { if(b.second-b.first>=0) return 0; return a.second>b.second; } } void dfs(int x,int y) { vector<pair<long,long> > vv; pair<long,long> n; long long z=a[x],m=a[x]; int i; for(i=0;i<v[x].size();i++) { n=v[x][i]; if(n.first==y) continue; dfs(n.first,x); dp[n.first].first+=n.second,dp[n.first].second-=n.second; if(dp[n.first].second<0) dp[n.first].first-=dp[n.first].second,dp[n.first].second=0; vv.push_back(dp[n.first]); } sort(vv.begin(),vv.end(),cmp); for(i=0;i<vv.size();i++) n=vv[i],m=min(m,z-n.first),z+=n.second-n.first; if(m>=0) dp[x].first=0,dp[x].second=z; else dp[x].first=-m,dp[x].second=z-m; } int main() { int T,n,x,y,z,i; scanf("%d",&T); while(T--) { scanf("%d",&n); for(i=1;i<=n;i++) { scanf("%d",&a[i]); v[i].clear(); } for(i=1;i<n;i++) { scanf("%d%d%d",&x,&y,&z); v[x].push_back(make_pair(y,z)); v[y].push_back(make_pair(x,z)); } dfs(1,0); printf("%lld\n",dp[1].first); } return 0; }
struct版本
#include<bits/stdc++.h> using namespace std; const int N=1e5+10; struct node { long long a,b; node(long long x=0,long long y=0) {a=x,b=y;} } dp[N]; vector<node> v[N]; int a[N]; bool cmp(node a,node b) { if(a.b-a.a>=0) { if(b.b-b.a<0) return 1; return a.a<b.a; } else { if(b.b-b.a>=0) return 0; return a.b>b.b; } } void dfs(int x,int y) { vector<node> vv; long long z=a[x],m=a[x]; int i; node n; for(i=0;i<v[x].size();i++) { n=v[x][i]; if(n.a==y)continue; dfs(n.a,x); dp[n.a].a+=n.b,dp[n.a].b-=n.b; if(dp[n.a].b<0) dp[n.a].a-=dp[n.a].b,dp[n.a].b=0; vv.push_back(dp[n.a]); } sort(vv.begin(),vv.end(),cmp); for(i=0;i<vv.size();i++) n=vv[i],m=min(m,z-n.a),z+=n.b-n.a; if(m>=0) dp[x].a=0,dp[x].b=z; else dp[x].a=-m,dp[x].b=z-m; } int main() { int T,n,x,y,z,i; scanf("%d",&T); while(T--) { scanf("%d",&n); for(i=1;i<=n;i++) { scanf("%d",&a[i]); v[i].clear(); } for(i=1;i<n;i++) { scanf("%d%d%d",&x,&y,&z); v[x].push_back(node(y,z)); v[y].push_back(node(x,z)); } dfs(1,0); printf("%lld\n",dp[1].a); } return 0; }