题面:
题意:
给定一棵n个点的树,每条边有两个取值 a[i],b[i]。
给定一个正整数 k,k≤20,表示这棵树有k条边选 a[i],其余的边选 b[i]。
问在上述所有的情况当中,树的直径最短是多少。
题解:
我们二分一个树的直径,将上述问题转化为判定树的直径为 mid 时能不能成立。
我们设 dp[i][j]表示在以 i 为根的子树中,有 j 条边的边权来自于 a 数组,且以 i 为根的子树中任意两点之间的距离都 ≤mid,且 i 的子树中与 i 点距离最远的点到 i 的距离的最小可能值。
那么在一个节点合并两个儿子的时候判断一下合并完成的最长路径(子树直径)是否小于等于 mid , 小于等于 mid 就转移,不然就不转移,最后若 dp[root][k] 不等于正无穷,就说整棵子树里存在这么一种方案, mid 就可行,反之不行。
官方题解:
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
#include<list>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
//#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=998244353;
const double eps=1e-1;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=20100;
const int maxp=1100;
const int maxm=500100;
const int up=100000;
int head[maxn],ver[maxn<<1],nt[maxn<<1],tot=1;
ll edgea[maxn<<1],edgeb[maxn<<1],dp[maxn][21],f[21],mid;
int si[maxn];
int n,k,x,y,a,b;
void add(int x,int y,int z1,int z2)
{
ver[++tot]=y,edgea[tot]=z1,edgeb[tot]=z2;
nt[tot]=head[x],head[x]=tot;
}
void init(int n)
{
for(int i=1;i<=n;i++)
head[i]=0;
tot=1;
}
void dfs(int x,int fa)
{
si[x]=dp[x][0]=0;
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y==fa) continue;
dfs(y,x);
ll a=edgea[i],b=edgeb[i];
int six=min(si[x],k),siy=min(si[y],k),nowsix=min(si[x]+si[y]+1,k);
si[x]+=si[y]+1;
for(int j=0;j<=nowsix;j++) f[j]=lnf;
for(int j=0;j<=six;j++)
{
for(int l=0;l<=siy&&l+j<=nowsix;l++)
{
if(j+l<nowsix&&dp[x][j]+dp[y][l]+a<=mid) f[j+l+1]=min(f[j+l+1],max(dp[x][j],dp[y][l]+a));
if(dp[x][j]+dp[y][l]+b<=mid) f[j+l]=min(f[j+l],max(dp[x][j],dp[y][l]+b));
}
}
for(int j=0;j<=nowsix;j++)
dp[x][j]=f[j];
}
}
int main(void)
{
int tt;
scanf("%d",&tt);
while(tt--)
{
ll l=0,r=0;
scanf("%d%d",&n,&k);
init(n);
for(int i=1;i<n;i++)
{
scanf("%d%d%d%d",&x,&y,&a,&b);
add(x,y,a,b);
add(y,x,a,b);
r=r+max(a,b);
}
ll ans=0;
while(l<=r)
{
mid=(l+r)>>1;
dfs(1,0);
if(dp[1][k]<lnf) ans=mid,r=mid-1;
else l=mid+1;
}
printf("%lld\n",ans);
}
return 0;
}