边分治模板题

代码

#include<bits/stdc++.h>
#define N 80010
#define INF 0x3f3f3f3f
#define eps 1e-5
#define pi 3.141592653589793
#define mod 998244353
#define P 1000000007
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<<x<<endl
#define mem(x,y) memset(x,0,sizeof(int)*(y+3))
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
typedef  pair<int,int> pp;
int n,cnt=1,la[N],sn,sz[N],del[N],mn,ct,tn;
 
vector<pp> a[N];
 
struct node{
    int from,to,w,nxt;
}G[N],A[N];
 
void add(int x,int y,int w){
    G[++cnt]={x,y,w,la[x]}; la[x]=cnt;
}
 
void rebuild(int x,int fa){
    int pre=0;
    for(auto i:a[x]){
        int v=i.fi,w=i.se;
        if (v==fa) continue;
        if (!pre){
            add(x,v,w),add(v,x,w); pre=x;
        }else{
            int k=++tn;
            add(k,v,w), add(v,k,w);
            add(k,pre,0), add(pre,k,0);
            pre=k;
        }
        rebuild(v,x);
    }
}
 
void findct(int x,int fa){
    sz[x]=1;
    for(int i=la[x];i;i=G[i].nxt){
        int v=G[i].to;
        if (del[i>>1]||v==fa) continue;
        findct(v,x);
        sz[x]+=sz[v];
        int tmp=max(sz[v],sn-sz[v]);
        if (tmp<mn){
            ct=i;
            mn=tmp;
        }
    }
}
 
LL ans=0;
 
unordered_map<int,int> mp;
int cc;int tet[2200];
int q[N];
 
void gao(int x,int fa,int dis){
    if (x<=n) tet[dis%2019]++,q[cc++]=dis%2019;
    for(int i=la[x];i;i=G[i].nxt){
        int v=G[i].to;
        if (v==fa||del[i>>1]) continue;
        gao(v,x,dis+G[i].w);
    }
}
 
 
void get(int x,int fa,int dis,int y){
    if (x<=n)ans+=tet[(2019*3-dis%2019-y%2019)%2019];
    for(int i=la[x];i;i=G[i].nxt){
        int v=G[i].to;
        if (v==fa||del[i>>1]) continue;
        get(v,x,dis+G[i].w,y);
    }
}
 
void dfs(int x){
    int u=G[x].from,v=G[x].to;
    del[x>>1]=1;
    cc=0;
    if (sz[u]<sz[v]) swap(u,v);
    gao(u,-1,0);
    get(v,-1,0,G[x].w);
    for(int i=0;i<cc;i++) tet[q[i]]=0;
    int tot=sn;
    sn=tot-sz[v]; mn=INF;
    findct(u,-1);
    if (mn!=INF)dfs(ct);
    sn=sz[v]; mn=INF;
    findct(v,-1);
    if (mn!=INF)dfs(ct);
}
 
int main(){
    while(~sc(n)){
        cnt=1; ans=0;
        for(int i=1;i<=n*2;i++) la[i]=0,del[i]=0,a[i].cl();
        for(int i=1;i<n;i++){
            int x,y,w;
            sccc(x,y,w);
            a[x].pb(pp(y,w)); a[y].pb(pp(x,w));
        }  
        tn=n;
        rebuild(1,-1);
        sn=tn; mn=INF;
        findct(1,-1);
        dfs(ct);
        printf("%lld\n",ans);
    }
    return 0;
}