题目大意

有一个n个点的树,走过每一条边时会减少一定的hp,每条边只能最多经过2次(即每个子树只能进出一次)。
吃了第i个点的苹果会回复a[i]的hp。可以在原地休息,每个单位时间回复1hp。hp不能为负。
从根节点1开始出发,初始hp为0,要求经过所有点后返回1号点,求最少休息时间。

解题思路

这道题显然是用贪心的树形dp来求解。我们可以把边和边的权值都存起来,然后dfs搜索。为了保证能回溯,把找到过的节点存入vector。
接下来考虑一下我们的贪心策略,那么一拍脑袋:我们应该优先走加hp更多,时间t+hp更大的节点。

用官方题解来充实我糟糕的表达.exe
  • 显而易见,我们最方便的操作就是在根节点恢复足够的hp再出发。所以我们要求的答案,就是最少的可以无缝访问所有节点的hp。
  • 我们设访问子树i需要的总hp为, 访问过程中最坏情况需要HP为
    由于访问的子树所需要的总hp-是一个定值,故我们只考虑子树访问顺序对于的影响。
    我们考虑排布子树的访问顺序,得到:

  • 考虑对儿子访问排序,访问顺序应满足相邻交换最优原则。设有相邻对,若不交换更优,则有
  • 我们将min,max转换为简单的逻辑表达式:
    肯定成立,所以我们先对于是否小于0讨论,将儿子的集合分成两部分A,B。其中𝐴为所有满足的所有形成的集合,B为所有满足的所有形成的集合,这样把 𝐴 中的某个元素排在 𝐵 中的另一个元素之前一定满足相邻交换原则。
  • 然后,我们对于A内的顺序考虑,一定有比较式后项成立,只需满足。我们再对于B内的顺序考虑,一定有比较式前项成立,只需满足

AC代码

pair版本

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
pair<long,long> dp[N]; //first为最少时间t,second为当前hp
int a[N];
vector<pair<long,long> > v[N];
bool cmp(pair<long,long> a,pair<long,long> b)
{
    if(a.second-a.first>=0)
    {
        if(b.second-b.first<0) return 1;
        return a.first<b.first;
    }
    else
    {
        if(b.second-b.first>=0) return 0;
        return a.second>b.second;
    }
}
void dfs(int x,int y)
{
    vector<pair<long,long> > vv;
    pair<long,long> n;
    long long z=a[x],m=a[x];
    int i;
    for(i=0;i<v[x].size();i++)
    {
        n=v[x][i];
        if(n.first==y) continue;
        dfs(n.first,x);
        dp[n.first].first+=n.second,dp[n.first].second-=n.second;
        if(dp[n.first].second<0)
            dp[n.first].first-=dp[n.first].second,dp[n.first].second=0;
        vv.push_back(dp[n.first]);
    }
    sort(vv.begin(),vv.end(),cmp);
    for(i=0;i<vv.size();i++)
        n=vv[i],m=min(m,z-n.first),z+=n.second-n.first;
    if(m>=0) dp[x].first=0,dp[x].second=z;
    else dp[x].first=-m,dp[x].second=z-m;
}
int main()
{
    int T,n,x,y,z,i;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        for(i=1;i<=n;i++)
        {
            scanf("%d",&a[i]);
            v[i].clear();
        }
        for(i=1;i<n;i++)
        {
            scanf("%d%d%d",&x,&y,&z);
            v[x].push_back(make_pair(y,z));
            v[y].push_back(make_pair(x,z));
        }
        dfs(1,0);
        printf("%lld\n",dp[1].first);
    }
    return 0;
}

struct版本

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
struct node
{
    long long a,b;
    node(long long x=0,long long y=0) {a=x,b=y;}
} dp[N];
vector<node> v[N];
int a[N];
bool cmp(node a,node b)
{
    if(a.b-a.a>=0)
    {
        if(b.b-b.a<0) return 1;
        return a.a<b.a;
    }
    else
    {
        if(b.b-b.a>=0) return 0;
        return a.b>b.b;
    }
}
void dfs(int x,int y)
{
    vector<node> vv;
    long long z=a[x],m=a[x];
    int i;
    node n;
    for(i=0;i<v[x].size();i++)
    {
        n=v[x][i];
        if(n.a==y)continue;
        dfs(n.a,x);
        dp[n.a].a+=n.b,dp[n.a].b-=n.b;
        if(dp[n.a].b<0) dp[n.a].a-=dp[n.a].b,dp[n.a].b=0;
        vv.push_back(dp[n.a]);
    }
    sort(vv.begin(),vv.end(),cmp);
    for(i=0;i<vv.size();i++)
        n=vv[i],m=min(m,z-n.a),z+=n.b-n.a;
    if(m>=0) dp[x].a=0,dp[x].b=z;
    else dp[x].a=-m,dp[x].b=z-m;
}
int main()
{
    int T,n,x,y,z,i;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        for(i=1;i<=n;i++)
        {
            scanf("%d",&a[i]);
            v[i].clear();
        }
        for(i=1;i<n;i++)
        {
            scanf("%d%d%d",&x,&y,&z);
            v[x].push_back(node(y,z));
            v[y].push_back(node(x,z));
        }
        dfs(1,0);
        printf("%lld\n",dp[1].a);
    }
    return 0;
}