之前用树形 DP 写过这道题，这次改用点分治再写一遍：统计树上路径权值之和为 2019 的倍数的点对数量。（思路比较直接，这里不再详细解释。）

#include <bits/stdc++.h>
using namespace std;
// Length of every per-node array in this file (20500 entries).
const int N = 20000 + 500;
// Modulus applied to accumulated path weights.
const int mod = 2019;
// One entry of a node's adjacency list.
struct Tree
{
    int to;   // node at the other end of the edge
    int val;  // weight of the edge
};
vector<Tree>v[N];
bool vis[N];
int sz[N],f[N],sum,root;
long long ans=0;
map<int,int>s;
void f_root(int u,int fa)
{
    sz[u]=1;f[u]=0;
    for(int i=0;i<(int)v[u].size();i++)
    {
        int son=v[u][i].to;
        if(son==fa||vis[son])   continue;
        f_root(son,u);sz[u]+=sz[son];
        f[u]=max(f[u],sz[son]);
    }f[u]=max(f[u],sum-f[u]);
    if(f[root]>f[u])    root=u;
}

void Cnt_node(int u,int fa,int cnt)
{
    for(int i=0;i<(int)v[u].size();i++)
    {
        int son=v[u][i].to;int Val=v[u][i].val;
        int pos=(cnt+Val)%mod;
        if(son==fa||vis[son])   continue;
        s[pos]++;
        Cnt_node(son,u,pos);
    }
}

void Cnt_ans(int u,int fa,int cnt)
{
    for(int i=0;i<(int)v[u].size();i++)
    {
        int son=v[u][i].to;int Val=v[u][i].val;
        int pos=(cnt+Val)%mod;
        if(son==fa||vis[son])   continue;
        if(pos%mod==0) ans++;ans+=s[(mod-(pos)%mod)%mod];
        Cnt_ans(son,u,pos);
    }
}

void cal(int u,int fa,int cnt)
{
    for(int i=0;i<(int)v[u].size();i++)
    {
        int son=v[u][i].to;int Val=v[u][i].val;
        int pos=cnt+Val;
        if(son==fa||vis[son])   continue;
        if(Val%mod==0)  ans++;ans+=s[(mod-(Val%mod))%mod];
        Cnt_ans(son,u,pos);
        s[Val]++;
        Cnt_node(son,u,pos);
    }
}

void solve(int u)
{
    cal(u,0,0);
    s.clear();vis[u]=true;
    for(int i=0;i<(int)v[u].size();i++)
    {
        int son=v[u][i].to;
        if(vis[son])    continue;
        root=0;sum=sz[son];
        f_root(son,u);
        solve(son);
    }
}

// Reads test cases until input runs out. Each case is a node count n followed
// by n-1 lines "a b c" describing an undirected edge a-b of weight c; the
// count accumulated in `ans` is printed on its own line.
int main()
{
    int n;
    // Compare against 1 rather than EOF: scanf returns 0 on a matching
    // failure, which would otherwise spin forever on malformed input.
    while(scanf("%d",&n)==1)
    {
        s.clear();
        ans=0;
        for(int i=1;i<=n;i++)   v[i].clear();
        for(int i=1;i<n;i++)
        {
            int a,b,c;
            // Stop on a truncated case instead of using uninitialised values.
            if(scanf("%d%d%d",&a,&b,&c)!=3) return 0;
            v[a].push_back({b,c});
            v[b].push_back({a,c});
        }
        // Index 0 acts as a sentinel candidate whose score any real node beats.
        root=0;
        f[0]=n;
        sum=n;
        memset(vis,false,sizeof vis);
        f_root(1,0);
        solve(root);
        printf("%lld\n",ans);
    }
    return 0;
}