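/*
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6060
Problem link: http://acm.hdu.edu.cn/showproblem.php?pid=6060

题意:给出一棵n个节点的树,要求将2-n号节点分成k部分,然后再将每一部分加上1号节点,定义每一部分的val为其中的点在原图上的最小斯坦纳树,问总的val最大可能是多少。
Statement: given a weighted tree with n nodes, split nodes 2..n into k parts and then add node 1 to every part. The val of a part is the weight of the minimum Steiner tree (in the original tree) connecting its nodes. Output the maximum possible total val.

解法:官方题解:
Solution (official editorial): root the tree at node 1. Every part contains node 1, so the edge from parent p to child v lies in a part's Steiner tree exactly when that part has a node in v's subtree. The subtree has sz[v] nodes, so at most min(sz[v], k) parts can use the edge. All these bounds are met at once: number nodes 2..n in DFS order and put the i-th one into part i mod k. Each subtree is then a contiguous DFS range, so it meets min(sz[v], k) distinct parts. Answer = sum over edges of len * min(sz[v], k).
*/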


#include <bits/stdc++.h>
using namespace std;
// Capacity for vertex-indexed arrays. Presumably the problem's bound on n plus a
// small safety margin (NOTE(review): confirm against the statement).
const int maxn = 1000000 + 10;
// 64-bit integer type used for the answer, a sum of weight * count products.
typedef long long LL;
// Buffered stdin reader / stdout writer for large inputs and outputs.
// Input is pulled in S-byte chunks with fread. Output is staged in wbuf and
// written with fwrite when the buffer fills or when the object is destroyed.
// NOTE: bytes written through this object are buffered separately from
// printf/stdout, so mixing the two can reorder output.
struct FastIO
{
    static const int S = 1310720;  // input chunk size and output buffer size, in bytes
    int wpos;                      // number of bytes currently staged in wbuf
    char wbuf[S];                  // pending output bytes
    FastIO() : wpos(0) {}
    // Returns the next input byte. When the input is exhausted this calls
    // exit(0) and ends the whole program. Callers with no explicit termination
    // condition (such as main's while(1) loop) rely on this. exit() still runs
    // static destructors, so ~FastIO flushes pending output.
    // The value comes from a plain char, so on platforms where char is signed,
    // bytes >= 0x80 are returned as negative numbers.
    inline int xchar()
    {
        static char buf[S];
        static int len = 0, pos = 0;
        if (pos == len)
            pos = 0, len = fread(buf, 1, S, stdin);
        if (pos == len) exit(0);
        return buf[pos ++];
    }
    // Skips bytes <= 32 (whitespace/control), then reads an unsigned decimal.
    // No overflow checking.
    inline int xuint()
    {
        int c = xchar(), x = 0;
        while (c <= 32) c = xchar();
        for (; '0' <= c && c <= '9'; c = xchar()) x = x * 10 + c - '0';
        return x;
    }
    // Like xuint, but also accepts a single leading '-'.
    inline int xint()
    {
        int s = 1, c = xchar(), x = 0;
        while (c <= 32) c = xchar();
        if (c == '-') s = -1, c = xchar();
        for (; '0' <= c && c <= '9'; c = xchar()) x = x * 10 + c - '0';
        return x * s;
    }
    // Reads one token (a maximal run of bytes > 32) into s and NUL-terminates it.
    // The caller must supply a large enough buffer.
    inline void xstring(char *s)
    {
        int c = xchar();
        while (c <= 32) c = xchar();
        for (; c > 32; c = xchar()) * s++ = c;
        *s = 0;
    }
    // Appends one byte to the output buffer, flushing it first if it is full.
    inline void wchar(int x)
    {
        if (wpos == S) fwrite(wbuf, 1, S, stdout), wpos = 0;
        wbuf[wpos ++] = x;
    }
    // Writes x in decimal. The parameter type is identical to LL (long long).
    inline void wint(long long x)
    {
        // Negate through an unsigned magnitude: -x is signed overflow (UB) for
        // LLONG_MIN, while unsigned wrap-around yields |x| exactly.
        unsigned long long u = static_cast<unsigned long long>(x);
        if (x < 0) wchar('-'), u = 0ULL - u;
        char s[24];  // 20 digits is the maximum for a 64-bit magnitude
        int n = 0;
        do
        {
            s[n ++] = static_cast<char>('0' + u % 10);
            u /= 10;
        } while (u);
        while (n--) wchar(s[n]);
    }
    // Writes a NUL-terminated string (without the terminator).
    inline void wstring(const char *s)
    {
        while (*s) wchar(*s++);
    }
    // Flushes any bytes still staged in wbuf.
    ~FastIO()
    {
        if (wpos) fwrite(wbuf, 1, wpos, stdout), wpos = 0;
    }
} io;
// Adjacency list in "forward star" form: head[u] is the index in E of the
// most recently added edge leaving u (-1 after init() when there is none),
// and E[i].next links to the previous edge from the same vertex.
int head[maxn];
int edgecnt;   // number of slots of E currently in use
int n, k;      // current test case: number of vertices and number of parts
struct edge{
    int to;    // destination vertex
    int len;   // edge weight
    int next;  // index of the next edge from the same source, or -1
}E[maxn*2];    // room for every tree edge stored in both directions
void init(){
    memset(head,-1,sizeof(head));
    edgecnt=0;
}
void add(int u, int v, int w){
    E[edgecnt].to = v, E[edgecnt].len = w, E[edgecnt].next = head[u], head[u] = edgecnt++;
}
int sz[maxn];
LL ans;
void dfs(int x, int fa){
    sz[x] = 1;
    for(int i = head[x]; ~i; i=E[i].next){
        int v = E[i].to;
        if(v == fa) continue;
        dfs(v, x);
        ans += (LL)E[i].len * min(sz[v], k);
        sz[x] += sz[v];
    }
}
// Processes test cases until the input runs out. There is no explicit exit
// condition here: io.xint() calls exit(0) once stdin is exhausted.
int main()
{
    for (;;)
    {
        n = io.xint();
        k = io.xint();
        init();
        // A tree on n vertices has n - 1 edges; store each in both directions.
        for (int edgesLeft = n - 1; edgesLeft > 0; --edgesLeft)
        {
            const int from = io.xint();
            const int to = io.xint();
            const int weight = io.xint();
            add(from, to, weight);
            add(to, from, weight);
        }
        ans = 0;
        dfs(1, -1);
        printf("%lld\n", ans);
    }
    return 0;
}