1.前言
去年湖南省赛有个弱化版的题.点分治和树形dp都可以写.以后一个月不会刻意写题了,准备考试,上海站加油吧,相信自己也相信队友~
https://ac.nowcoder.com/acm/contest/1099/I
2.思路
对于题目给定的,假如我们考虑分治一个点,假如分治1,那么我肯定是统计1的所经过链的答案,但是呢,我们肯定会出现不合法的状态,比如我从1到4到6,然后从1到4到5,显然不是一条链,但是这样被计算了,我们只要保留1 4这条边,容斥一下就可以了.当然这只是讲点分治,并没有讲这个题,关于这个题呢,我们可以考虑从1个节点往下数出一个数为a,从下往上数为b.很显然题目就是要你计算
(b*10^dep(a)+a)%m=0
的点对个数,但是这样并不太好计算,因为我们在统计答案的时候a和a不在一起,这也好办,直接把原式/10^(dep(a))即可.因为m和10互质m所以:现在就是求
(b+a*iv(10^dep(a)))%m=0
的点对数量,直接点分治维护即可...emmm,我写了几个dsu on tree发现,原来我之前的点分治题单并没有做多少...可恶啊~
3.代码
#include <bits/stdc++.h> using namespace std; #define mp make_pair #define int long long typedef long long ll; const int N=1e5+50; struct Tree{ int to,val; }; int root=0; vector<Tree>v[N]; map<int,int>s; pair<int,int>dig[N<<2]; ll ans=0,p[N]; int dep[N]; int n,m,num=0; inline int exgcd(int a,int b,int& x,int& y) { if(!b) { x=1,y=0; return a; } int d=exgcd(b,a%b,x,y); int z=x; x=y; y=z-a/b*y; return d; } inline int inv(int a,int m)//a在mod m意义下的逆元 { int x,y,d=exgcd(a,m,x,y); return d==1?(x%m+m)%m:-1; } bool vis[N];int sz[N],f[N],sum; void f_root(int u,int fa) { sz[u]=1,f[u]=0; for(int i=0;i<(int)v[u].size();i++) { int son=v[u][i].to; if(son==fa||vis[son]) continue; f_root(son,u);sz[u]+=sz[son]; f[u]=max(f[u],sz[son]); }f[u]=max(f[u],sum-f[u]); if(f[root]>f[u]) root=u; } //把根节点放到下往上数更容易~ void init(int u,int fa,int p1,int p2,int Dep)//当前节点,父亲节点,下往上,上往下,深度. { if(Dep>=0) s[p1]++,dig[++num]=mp(p2,Dep); for(int i=0;i<(int)v[u].size();i++) { int son=v[u][i].to;int w=v[u][i].val; if(fa==son||vis[son]) continue; int p3=(p[Dep+1]*w+p1)%m; int p4=(p2*10+w)%m; init(son,u,p3,p4,Dep+1); } } int cal(int u,int d) { s.clear(); int res=0;num=0; if(d) init(u,0,d%m,d%m,0); else init(u,0,0,0,-1); for(int i=1;i<=num;i++) { int temp=((-dig[i].first*inv(p[dig[i].second+1],m))%m+m)%m; res+=s[temp]; //if(s.find(temp)!=s.end()) res+=s[temp]; if(!d) res+=(!dig[i].first); }if(!d) res+=s[0]; return res; } void solve(int u) { ans+=cal(u,0);vis[u]=true; for(int i=0;i<(int)v[u].size();i++) { int son=v[u][i].to;int Val=v[u][i].val; if(vis[son]) continue; ans-=cal(son,Val); sum=sz[son];root=0;f[0]=n;f_root(son,u); solve(root); } } signed main() { scanf("%d%d",&n,&m); for(int i=1;i<n;i++) { int a,b,c; scanf("%lld%lld%lld",&a,&b,&c); a++,b++; v[a].push_back({b,c}); v[b].push_back({a,c}); }p[0]=1; for(int i=1;i<=n;i++) p[i]=p[i-1]*10%m; f[0]=n,root=0,sum=n;f_root(1,0); solve(root); printf("%lld\n",ans); }