主要是为了学习一波 wqs 二分。
首先发现这个题目等价于把树切成 k+1 块 每块内选个直径加起来,也就等价于在树上找 条不相交的链使得收益最大。
考虑树上每一条链都可以拆成两条深度单调的链,所以我们可以设 表示现在在 点 已经选择了 个链 当前 点的状态是(没有链经过/有一条单链待匹配/有链经过),转移一下就可以了,复杂度
对于这一类“恰好选 个” 的问题,如果优化不掉 这一维,就可以考虑 wqs 二分。
我们设 表示选 个的最优值,首先发现 是凸的(导函数单调),也就是平面上若干个点 构成了一个凸壳。上面的暴力 dp 就是求出凸包上所有点的过程,但我们实际上只想求出题目要求的点 。我们考虑我们二分一条斜率固定的直线 去切这个凸包。这个时候这条直线切到的点 满足 ,回忆凸壳的定义我们发现 应当取最大值,我们求出 在什么时候取到最大值就可以求出 ,然后根据导函数的单调性来决定斜率的增减。这里求最大值可以看成每个物品有额外的代价/收益,然后扔掉数量的限制去 dp 最大收益的数量。
但是这样会有一个问题:让 C 取最大值的点可能是不唯一的,也就是我们二分最后得到的结果不一定是一个点,而是一条直线上的许多点。这时候我们观察 发现 都是固定的,我们只需要让 尽量大就可以取到最大值了。(根据题目会不同)
总结做题思路:优化不下去 暴力 dp 打表看差分表 发现凸函数,用 wqs 二分。
#include <algorithm> #include <iostream> #include <cstring> #include <climits> #include <cstdlib> #include <cstdio> #include <bitset> #include <vector> #include <cmath> #include <ctime> #include <queue> #include <stack> #include <map> #include <set> #define fi first #define se second #define U unsigned #define P std::pair #define LL long long #define pb push_back #define MP std::make_pair #define all(x) x.begin(),x.end() #define CLR(i,a) memset(i,a,sizeof(i)) #define FOR(i,a,b) for(int i = a;i <= b;++i) #define ROF(i,a,b) for(int i = a;i >= b;--i) #define DEBUG(x) std::cerr << #x << '=' << x << std::endl const int MAXN = 3e5 + 5; struct Edge{ int to,w,nxt; }e[MAXN<<1]; int head[MAXN],cnt; inline void add(int u,int v,int w){ e[++cnt] = (Edge){v,w,head[u]};head[u] = cnt; e[++cnt] = (Edge){u,w,head[v]};head[v] = cnt; } int n,k; struct Node{ LL f,g; Node(LL f=0,LL g=0) : f(f),g(g) {} inline Node operator + (const Node &t) const { return Node(f+t.f,g+t.g); } inline Node operator + (const LL &t) const { return Node(f+t,g); } inline bool operator < (const Node &t) const { return f == t.f ? g < t.g : f < t.f; } }F[MAXN][3]; LL ext;// 花费 inline Node upd(const Node &t){ return Node(t.f-ext,t.g+1); } inline void dfs(int v,int fa=0){ F[v][1] = F[v][2] = F[v][0] = Node(0,0); F[v][2] = std::max(F[v][2],Node(-ext,1)); for(int i = head[v];i;i = e[i].nxt){ if(e[i].to == fa) continue; dfs(e[i].to,v); F[v][2] = std::max(F[v][2]+F[e[i].to][0],upd(F[v][1]+F[e[i].to][1]+e[i].w)); F[v][1] = std::max(F[v][1]+F[e[i].to][0],F[v][0]+F[e[i].to][1]+e[i].w); F[v][0] = F[v][0]+F[e[i].to][0]; } F[v][0] = std::max(F[v][0],std::max(F[v][2],upd(F[v][1]))); } inline int chk(LL x){ ext = x; dfs(1);return F[1][0].g; } int main(){ scanf("%d%d",&n,&k);++k; LL l=0,r=0; FOR(i,2,n){ int u,v,w;scanf("%d%d%d",&u,&v,&w);add(u,v,w); r += std::abs(w); } l = -r; LL ans; while(l <= r){ LL mid = (l + r) >> 1; if(chk(mid) >= k) ans = mid,l = mid+1; else r = mid-1; } chk(ans); printf("%lld\n",F[1][0].f+1ll*k*ans); return 0; }