题目链接:https://www.luogu.com.cn/problem/P4149
题目大意:
我们用个s[i]统计之前子树距离为i的最少边数,合并换根就行了。
#include <bits/stdc++.h>
#define LL long long
using namespace std;
const int N=200020, inf=1e9+1000;
struct edg{
int to, w, next;
}e[N*4];
int tot, root, ans, allnode, n, k;
int head[N], vis[N], siz[N], d[N];
struct deep{
int s;
int d;
}deep[N];
int f[N];//求重心的最多子节点个数
int q[N];//mem辅助数组
int s[5000005];
void add(int u, int v, int w){
e[tot].to=v, e[tot].next=head[u];
e[tot].w=w, head[u]=tot++;
}
void getroot(int u, int fa){//求重心
siz[u]=1;
f[u]=0;
for(int i=head[u]; i!=-1; i=e[i].next){
int to=e[i].to;
if(to==fa||vis[to]){
continue;
}
getroot(to, u);
siz[u]+=siz[to];
f[u]=max(f[u], siz[to]);
}
f[u]=max(allnode-siz[u], f[u]);
if(f[u]<f[root]){
root=u;
}
}
void getdeep(int u, int fa, int ds){//获取子树所有节点与根的距离
deep[++deep[0].s].s=d[u]; deep[deep[0].s].d=ds;
for(int i=head[u]; i!=-1; i=e[i].next){
int to=e[i].to;
if(to==fa||vis[to]){
continue;
}
d[to]=d[u]+e[i].w;
getdeep(to, u, ds+1);
}
}
void cal(int u){//计算当前以重心x的子树下,所有情况的答案
int Len=0;
d[u]=0; s[0]=0;
for(int i=head[u]; i!=-1; i=e[i].next){
int to=e[i].to;
if(vis[to]){
continue;
}
deep[0].s=0;
d[to]=e[i].w;
getdeep(to, u, 1);
for(int i=1; i<=deep[0].s; i++){//统计
q[++Len]=deep[i].s;
if(k>=deep[i].s){
if(s[k-deep[i].s]!=-1){//如果之前存在k-deep[i].s
if(s[k]==-1){
s[k]=s[k-deep[i].s]+deep[i].d;
}
else{
s[k]=min(s[k], s[k-deep[i].s]+deep[i].d);
}
}
}
}
for(int i=1; i<=deep[0].s; i++){
if(deep[i].s>k){
continue;
}
if(s[deep[i].s]==-1){
s[deep[i].s]=deep[i].d;
}
else{
s[deep[i].s]=min(s[deep[i].s], deep[i].d);
}
}
}
if(s[k]!=-1&&s[k]<ans){
ans=s[k];
}
s[k]=-1;
for(int i=1; i<=Len; i++){
s[q[i]]=-1;
}
}
void work(int u){//以x为重心进行计算
vis[u]=1;
cal(u);
for(int i=head[u]; i!=-1; i=e[i].next){
int to=e[i].to;
if(vis[to]){
continue;
}
allnode=siz[to];//继续分治
root=0;
getroot(to, u);
work(root);
}
}
int main(){
int u, v, w;
while(~scanf("%d%d", &n, &k)){
memset(head, -1, sizeof(head));
memset(vis, 0, sizeof(vis));
memset(s, -1, sizeof(s));
tot=1;
for(int i=1; i<=n-1; i++){
scanf("%d%d%d", &u, &v, &w);
add(u+1, v+1 ,w), add(v+1, u+1, w);
}
root=0;ans=1<<30;
allnode=n, f[0]=inf;
getroot(1, 0);
work(root);
if(ans==(1<<30)){
printf("-1\n");
}
else{
printf("%d\n",ans);
}
}
return 0;
}