【题意】求树上距离小于等于K的点对有多少个?

【解题方法】不愧是男人8题,从TLE写带WA,最后过了,经历了10+次。

一个重要的问题是,为了防止退化,所以每次都要找到树的重心然后分治下去,所谓重心,就是删掉此结点后,剩下的结点最多的树结点个数最小。

每次分治,我们首先算出重心,为了计算重心,需要进行两次dfs,第一次把以每个结点为根的子树大小求出来,第二次是从这些结点中找重心

找到重心后,需要统计所有结点到重心的距离,看其中有多少对小于等于K,这里采用的方法就是把所有的距离存在一个数组里,进行快速排序,这是nlogn的,然后用一个经典的相向搜索O(n)时间内解决。但是这些求出来满足小于等于K的里面只有那些路径经过重心的点对才是有效的,也就是说在同一颗子树上的肯定不算数的,所以对每颗子树,把子树内部的满足条件的点对减去。

最后的复杂度是n logn logn    其中每次快排是nlogn 而递归的深度为logn。我开始写的时候,TLE了8次,但是还是找不到TLE在什么地方,最后换姿势写过了,路过菊苣如果看到我哪里写T了,欢迎指出来下,不胜感激。。

【TLE 版本】

//tree merge

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 10005;
const int inf  = 0x3f3f3f3f;
int n,k,ans,he,numson;
int head[maxn],tot;
int cnt;
int siz[maxn],dis[maxn];
int vis[maxn];
struct edge{
    int to,next,w;
}E[maxn*4];
void init(){
    cnt=0;
    memset(head,-1,sizeof(head));
}
void addedge(int u,int v,int w){
    E[cnt].to=v,E[cnt].w=w,E[cnt].next=head[u],head[u]=cnt++;
}
//int zx(int u,int fa){
//    siz[u]=1;
//    int he=-1;
//    for(int i=head[u]; ~i; i=E[i].next){
//        int v=E[i].to;
//        if(v==u) continue;
//        zx(v,u);
//        siz[u]+=siz[v];
//        he=max(he,siz[v]);
//    }
//    he=max(he,n-siz[u]);
//    return he;
//}
void zx(int u,int fa){
    siz[u]=1;
    int temp=0;
    for(int i=head[u]; i+1; i=E[i].next){
        int v=E[i].to;
        if(v==fa||vis[v]) continue;
        zx(v,u);
        siz[u]+=siz[v];
        temp=max(temp,siz[v]);
    }
    temp=max(temp,tot-siz[u]);
    if(temp<numson){
        temp=numson;
        he=u;
    }
}
void dfs1(int u,int fa,int d){
    dis[tot++]=d;
    for(int i=head[u]; i+1; i=E[i].next){
        int v=E[i].to,w=E[i].w;
        if(v==fa||vis[v])  continue;
        dfs1(v,u,d+w);
    }
}
int cal(int u,int d){
    tot=0;
    dfs1(u,u,d);
    siz[u]=tot;
    sort(dis,dis+tot);
    //two pointers query ans.
//    int i=0,j=0;
//    for(int i=0,j=0; i< n; i++){
//        while()
//    }
    int sum = 0;
    for(int i=0,j=tot-1; i<j; ){
        if(dis[i]+dis[j]<=k){
            sum += (j-i);
            i++;
        }else{
            j--;
        }
    }
    return sum;
}

void dfs2(int u){
    tot=siz[u];
    he=u;
    numson=n;
    zx(u,u);
    //cout<<he<<endl;
    vis[he]=1;
    ans+=cal(he,0);
    for(int i=head[he]; i+1; i=E[i].next){
        int v=E[i].to,w=E[i].w;
        if(vis[v]) continue;
        ans-=cal(v,w);
        dfs2(v);
    }
}

int main()
{
    while(scanf("%d%d",&n,&k)!=EOF)
    {
        if(n==0&&k==0) break;
        init();
        int u,v,w;
        for(int i=1; i<n; i++){
            scanf("%d%d%d",&u,&v,&w);
            addedge(u,v,w);
            addedge(v,u,w);
        }
        memset(vis,0,sizeof(vis));
        ans=0;
        siz[1]=n;
        dfs2(1);
        printf("%d\n",ans);
    }
    return 0;
}

【AC 代码】

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 10005;
int n,k,num,ans,root,tot;
int head[maxn],vis[maxn];
int siz[maxn],mx[maxn],mi,dis[maxn];
struct edge{
    int to,next,w;
}E[maxn*2];
void init(){
    memset(vis,0,sizeof(vis));
    memset(head,-1,sizeof(head));
    tot=0;
}
void addedge(int u,int v,int w){
    E[tot].to=v,E[tot].w=w,E[tot].next=head[u],head[u]=tot++;
}
void dfs1(int u,int fa){
    siz[u]=1;mx[u]=0;
    for(int i=head[u]; ~i; i=E[i].next){
        int v=E[i].to;
        if(v==fa||vis[v]) continue;
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>mx[u]) mx[u]=siz[v];
    }
}
void dfs2(int r,int u,int fa){
    if(siz[r]-siz[u]>mx[u]) mx[u]=siz[r]-siz[u];
    if(mx[u]<mi) mi=mx[u],root=u;
    for(int i=head[u]; ~i; i=E[i].next){
        int v=E[i].to;
        if(v==fa||vis[v]) continue;
        dfs2(r,v,u);
    }
}
void dfs3(int d,int u,int fa){
    dis[num++]=d;
    for(int i=head[u]; ~i; i=E[i].next){
        int v=E[i].to;
        if(v==fa||vis[v]) continue;
        dfs3(d+E[i].w,v,u);
    }
}
int cal(int u,int d){
    num=0;
    dfs3(d,u,-1);
    sort(dis,dis+num);
    int sum=0;
    for(int i=0,j=num-1; i<j; ){
        if(dis[i]+dis[j]<=k){
            sum+=(j-i);
            i++;
        }else{
            j--;
        }
    }
    return sum;
}
void dfs4(int u){
    mi=n;
    dfs1(u,-1);
    dfs2(u,u,-1);
    vis[root]=1;
    //cout<<root<<endl;
    ans+=cal(root,0);
    for(int i=head[root]; ~i; i=E[i].next){
        int v=E[i].to;
        if(vis[v]) continue;
        ans-=cal(v,E[i].w);
        dfs4(v);
    }
}
int main()
{
    while(scanf("%d%d",&n,&k)!=EOF)
    {
        if(n==0&&k==0) break;
        init();
        ans=0;
        int u,v,w;
        for(int i=1; i<n; i++){
            scanf("%d%d%d",&u,&v,&w);
            addedge(u,v,w);
            addedge(v,u,w);
        }
        dfs4(1);
        printf("%d\n",ans);
    }
    return 0;
}