主要是为了学习一波 wqs 二分。

首先发现这个题目等价于把树切成 k+1 块 每块内选个直径加起来,也就等价于在树上找 条不相交的链使得收益最大。

考虑树上每一条链都可以拆成两条深度单调的链,所以我们可以设 表示现在在 点 已经选择了 个链 当前 点的状态是(没有链经过/有一条单链待匹配/有链经过),转移一下就可以了,复杂度

对于这一类“恰好选 个” 的问题,如果优化不掉 这一维,就可以考虑 wqs 二分。

我们设 表示选 个的最优值,首先发现 是凸的(导函数单调),也就是平面上若干个点 构成了一个凸壳。上面的暴力 dp 就是求出凸包上所有点的过程,但我们实际上只想求出题目要求的点 。我们考虑我们二分一条斜率固定的直线 去切这个凸包。这个时候这条直线切到的点 满足 ,回忆凸壳的定义我们发现 应当取最大值,我们求出 在什么时候取到最大值就可以求出 ,然后根据导函数的单调性来决定斜率的增减。这里求最大值可以看成每个物品有额外的代价/收益,然后扔掉数量的限制去 dp 最大收益的数量。

但是这样会有一个问题:让 C 取最大值的点可能是不唯一的,也就是我们二分最后得到的结果不一定是一个点,而是一条直线上的许多点。这时候我们观察 发现 都是固定的,我们只需要让 尽量大就可以取到最大值了。(根据题目会不同)

总结做题思路:优化不下去 暴力 dp 打表看差分表 发现凸函数,用 wqs 二分。

#include <algorithm>
#include <iostream>
#include <cstring>
#include <climits>
#include <cstdlib>
#include <cstdio>
#include <bitset>
#include <vector>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <map>
#include <set>

#define fi first
#define se second
#define U unsigned
#define P std::pair
#define LL long long
#define pb push_back
#define MP std::make_pair
#define all(x) x.begin(),x.end()
#define CLR(i,a) memset(i,a,sizeof(i))
#define FOR(i,a,b) for(int i = a;i <= b;++i)
#define ROF(i,a,b) for(int i = a;i >= b;--i)
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl

const int MAXN = 3e5 + 5;

struct Edge{
    int to,w,nxt;
}e[MAXN<<1];
int head[MAXN],cnt;

inline void add(int u,int v,int w){
    e[++cnt] = (Edge){v,w,head[u]};head[u] = cnt;
    e[++cnt] = (Edge){u,w,head[v]};head[v] = cnt;
}

int n,k;

struct Node{
    LL f,g;

    Node(LL f=0,LL g=0) : f(f),g(g) {}

    inline Node operator + (const Node &t) const {
        return Node(f+t.f,g+t.g);
    }

    inline Node operator + (const LL &t) const {
        return Node(f+t,g);
    }

    inline bool operator < (const Node &t) const {
        return f == t.f ? g < t.g : f < t.f;
    }
}F[MAXN][3];
LL ext;// 花费

inline Node upd(const Node &t){
    return Node(t.f-ext,t.g+1);
}

inline void dfs(int v,int fa=0){
    F[v][1] = F[v][2] = F[v][0] = Node(0,0);
    F[v][2] = std::max(F[v][2],Node(-ext,1));
    for(int i = head[v];i;i = e[i].nxt){
        if(e[i].to == fa) continue;
        dfs(e[i].to,v);
        F[v][2] = std::max(F[v][2]+F[e[i].to][0],upd(F[v][1]+F[e[i].to][1]+e[i].w));
        F[v][1] = std::max(F[v][1]+F[e[i].to][0],F[v][0]+F[e[i].to][1]+e[i].w);
        F[v][0] = F[v][0]+F[e[i].to][0];
    }
    F[v][0] = std::max(F[v][0],std::max(F[v][2],upd(F[v][1])));
}

inline int chk(LL x){
    ext = x;
    dfs(1);return F[1][0].g;
}

int main(){
    scanf("%d%d",&n,&k);++k;
    LL l=0,r=0;
    FOR(i,2,n){
        int u,v,w;scanf("%d%d%d",&u,&v,&w);add(u,v,w);
        r += std::abs(w);
    }
    l = -r;
    LL ans;
    while(l <= r){
        LL mid = (l + r) >> 1;
        if(chk(mid) >= k) ans = mid,l = mid+1;
        else r = mid-1;
    }
    chk(ans);
    printf("%lld\n",F[1][0].f+1ll*k*ans);
    return 0;
}