题目
Code
#include<bits/stdc++.h>
using namespace std;
const int N=200002;
struct node{
int to,ne,w;
}e[N<<1];
int h[N],sum,mx[N],sz[N],dis[N],rem[N],tot,x,y,z,i,k,rt,n,m,q[N],ans,sec[N],d[N],jud[1000002],num;
bool vis[N];
inline char gc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int rd(){
int x=0,fl=1;char ch=gc();
for (;ch<48||ch>57;ch=gc())if(ch=='-')fl=-1;
for (;48<=ch&&ch<=57;ch=gc())x=(x<<3)+(x<<1)+(ch^48);
return x*fl;
}
void add(int x,int y,int z){e[++tot]=(node){y,h[x],z},h[x]=tot;}
void getrt(int u,int fa){
sz[u]=1,mx[u]=0;
for (int i=h[u],v;i;i=e[i].ne)
if ((v=e[i].to)!=fa && !vis[v]){
getrt(v,u);
sz[u]+=sz[v];
mx[u]=max(mx[u],sz[v]);
}
mx[u]=max(mx[u],sum-sz[u]);
if (mx[u]<mx[rt]) rt=u;
}
void getdis(int u,int fa){
if (dis[u]<=k) rem[++num]=dis[u],sec[num]=d[u];
for (int i=h[u],v;i;i=e[i].ne)
if ((v=e[i].to)!=fa && !vis[v]){
d[v]=d[u]+1;
dis[v]=dis[u]+e[i].w;
getdis(v,u);
}
}
void calc(int u){
int p=0;jud[0]=0;
for (int i=h[u],v;i;i=e[i].ne)
if (!vis[v=e[i].to]){
num=0,dis[v]=e[i].w,d[v]=1;
getdis(v,0);
for (int j=1;j<=num;j++) ans=min(ans,sec[j]+jud[k-rem[j]]);
for (int j=1;j<=num;j++) q[++p]=rem[j],jud[rem[j]]=min(jud[rem[j]],sec[j]);
}
for (int i=1;i<=p;i++) jud[q[i]]=1e9;
}
void solve(int u){
vis[u]=1,calc(u);
for (int i=h[u],v;i;i=e[i].ne)
if (!vis[v=e[i].to]){
sum=sz[v],mx[rt=0]=1e9;
getrt(v,0);
solve(rt);
}
}
int main(){
n=rd(),k=rd();
memset(jud,63,sizeof(jud));
for (i=1;i<n;i++) x=rd()+1,y=rd()+1,z=rd(),add(x,y,z),add(y,x,z);
ans=1e9;
sum=mx[rt=0]=n;
getrt(1,0);
solve(rt);
printf("%d",ans==1e9?-1:ans);
}