1.前言

去年湖南省赛有个弱化版的题.点分治和树形dp都可以写.以后一个月不会刻意写题了,准备考试,上海站加油吧,相信自己也相信队友~
https://ac.nowcoder.com/acm/contest/1099/I


2.思路

图片说明
对于题目给定的,假如我们考虑分治一个点,假如分治1,那么我肯定是统计1的所经过链的答案,但是呢,我们肯定会出现不合法的状态,比如我从1到4到6,然后从1到4到5,显然不是一条链,但是这样被计算了,我们只要保留1 4这条边,容斥一下就可以了.当然这只是讲点分治,并没有讲这个题,关于这个题呢,我们可以考虑从1个节点往下数出一个数为a,从下往上数为b.很显然题目就是要你计算

(b*10^dep(a)+a)%m=0

的点对个数,但是这样并不太好计算,因为我们在统计答案的时候a和a不在一起,这也好办,直接把原式/10^(dep(a))即可.因为m和10互质m所以:现在就是求

(b+a*iv(10^dep(a)))%m=0

的点对数量,直接点分治维护即可...emmm,我写了几个dsu on tree发现,原来我之前的点分治题单并没有做多少...可恶啊~


3.代码

#include <bits/stdc++.h>
using namespace std;
#define mp make_pair
#define int long long
typedef long long ll;
const int N=1e5+50;
struct Tree{
    int to,val;
};
int root=0;
vector<Tree>v[N];
map<int,int>s;
pair<int,int>dig[N<<2];
ll ans=0,p[N];
int dep[N];
int n,m,num=0;
inline int exgcd(int a,int b,int& x,int& y)
{
    if(!b) { x=1,y=0; return a; }
    int d=exgcd(b,a%b,x,y);
    int z=x; x=y; y=z-a/b*y;
    return d;
}

inline int inv(int a,int m)//a在mod m意义下的逆元
{
    int x,y,d=exgcd(a,m,x,y);
    return d==1?(x%m+m)%m:-1;
}

bool vis[N];int sz[N],f[N],sum;
void f_root(int u,int fa)
{
    sz[u]=1,f[u]=0;
    for(int i=0;i<(int)v[u].size();i++)
    {
        int son=v[u][i].to;
        if(son==fa||vis[son])   continue;
        f_root(son,u);sz[u]+=sz[son];
        f[u]=max(f[u],sz[son]);
    }f[u]=max(f[u],sum-f[u]);
    if(f[root]>f[u])    root=u;
}

//把根节点放到下往上数更容易~
void init(int u,int fa,int p1,int p2,int Dep)//当前节点,父亲节点,下往上,上往下,深度.
{
    if(Dep>=0)  s[p1]++,dig[++num]=mp(p2,Dep);
    for(int i=0;i<(int)v[u].size();i++)
    {
        int son=v[u][i].to;int w=v[u][i].val;
        if(fa==son||vis[son])   continue;
        int p3=(p[Dep+1]*w+p1)%m;
        int p4=(p2*10+w)%m;
        init(son,u,p3,p4,Dep+1);
    }
}


int cal(int u,int d)
{
    s.clear();
    int res=0;num=0;
    if(d)   init(u,0,d%m,d%m,0);
    else    init(u,0,0,0,-1);
    for(int i=1;i<=num;i++)
    {
        int temp=((-dig[i].first*inv(p[dig[i].second+1],m))%m+m)%m;
        res+=s[temp];
        //if(s.find(temp)!=s.end()) res+=s[temp];
        if(!d)  res+=(!dig[i].first);
    }if(!d) res+=s[0];
    return res;
}

void solve(int u)
{
    ans+=cal(u,0);vis[u]=true;
    for(int i=0;i<(int)v[u].size();i++)
    {
        int son=v[u][i].to;int Val=v[u][i].val;
        if(vis[son])    continue;
        ans-=cal(son,Val);
        sum=sz[son];root=0;f[0]=n;f_root(son,u);
        solve(root);
    }
}

signed main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++)
    {
        int a,b,c;
        scanf("%lld%lld%lld",&a,&b,&c);
        a++,b++;
        v[a].push_back({b,c});
        v[b].push_back({a,c});
    }p[0]=1;
    for(int i=1;i<=n;i++)   p[i]=p[i-1]*10%m;
    f[0]=n,root=0,sum=n;f_root(1,0);
    solve(root);
    printf("%lld\n",ans);
}