C-嗷嗷嗷嗷嗷_一起来做题~欢乐赛7 (nowcoder.com)

题目描述

给你一棵n个节点的带标号无根树,每个节点都有a[i]个人,每一条边都有边权表示长度。你可以选择任意一个节点为根节点u让其他节点的所有人都聚集到u

定义一个不方便值:所有人走到根节点的最短距离之和,问如何选择根节点能使距离不方便值最小输出最小的不方便值

样例

5 
1 
1 
0 
0 
2 
1 3 1 
2 3 2 
3 4 3 
4 5 3 
15

算法1

(换根dp)
前导
  • 这一题是A题的扩展版
  • 思路和A题一样只要稍微修改一下状态转移方程即可

状态转移方程

表示以u为根节点的子树不方便值的大小

  1. 树形dp部分:

    解释:为边<u,v>的长度

  2. 换根dp部分

    解释:tot为总人数

时间复杂度

参考文献

C++ 代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
// #include <unordered_map>
#include <vector>
#include <queue>
#include <set>
#include <bitset>
#include <cmath>
#include <map>

#define x first
#define y second

#define P 131

#define lc u << 1
#define rc u << 1 | 1

using namespace std;
typedef long long LL;
const int N = 100010;
const LL INF = 0x3f3f3f3f3f3f3f3fll;
int h[N],ne[N * 2],e[N * 2],w[N * 2],idx;
int v[N];
LL f[N];
int sz[N];
int tot;
LL ans;
int n;

void add(int a,int b,int c)
{
    e[idx] = b,w[idx] = c,ne[idx] = h[a],h[a] = idx ++;
}

void dfs1(int u,int father)
{
    f[u] = 0;
    sz[u] = v[u];
    for(int i = h[u];~i;i = ne[i])
    {
        int j = e[i];
        if(j == father) continue;
        dfs1(j,u);
        f[u] += f[j] + 1ll * sz[j] * w[i];
        sz[u] += sz[j];
    }
}

void dfs2(int u,int father)
{
    for(int i = h[u];~i;i = ne[i])
    {
        int j = e[i];
        if(j == father) continue;
        f[j] = f[u] - 1ll * sz[j] * w[i] + 1ll * (tot - sz[j]) * w[i];
        dfs2(j,u);
    }
    ans = min(ans,f[u]);
}

void solve()
{
    scanf("%d",&n);
    for(int i = 1;i <= n;i ++) h[i] = -1;
    for(int i = 1;i <= n;i ++) scanf("%d",&v[i]);
    for(int i = 0;i < n - 1;i ++)
    {
        int a,b,c;
        scanf("%d%d%d",&a,&b,&c);
        add(a,b,c);
        add(b,a,c);
    }
    ans = INF;
    dfs1(1,0);
    tot = sz[1];//注意总人数不再是n
    dfs2(1,0);
    printf("%lld\n",ans);
} 

int main()
{
    int _ = 1;

    // freopen("network.in","r",stdin);
    // freopen("network.out","w",stdout);
    // init(N - 1); 

    // std::ios_base::sync_with_stdio(0);
    // cin.tie(0);
    // cin >> _;

    // scanf("%d",&_);
    while(_ --)
    {
        // scanf("%lld%lld",&n,&m);
        solve();
        // test();
    }
    return 0;
}