1.前言
去年湖南省赛有个弱化版的题.点分治和树形dp都可以写.以后一个月不会刻意写题了,准备考试,上海站加油吧,相信自己也相信队友~
https://ac.nowcoder.com/acm/contest/1099/I
2.思路
对于题目给定的,假如我们考虑分治一个点,假如分治1,那么我肯定是统计1的所经过链的答案,但是呢,我们肯定会出现不合法的状态,比如我从1到4到6,然后从1到4到5,显然不是一条链,但是这样被计算了,我们只要保留1 4这条边,容斥一下就可以了.当然这只是讲点分治,并没有讲这个题,关于这个题呢,我们可以考虑从1个节点往下数出一个数为a,从下往上数为b.很显然题目就是要你计算
(b*10^dep(a)+a)%m=0
的点对个数,但是这样并不太好计算,因为我们在统计答案的时候a和a不在一起,这也好办,直接把原式/10^(dep(a))即可.因为m和10互质m所以:现在就是求
(b+a*iv(10^dep(a)))%m=0
的点对数量,直接点分治维护即可...emmm,我写了几个dsu on tree发现,原来我之前的点分治题单并没有做多少...可恶啊~
3.代码
#include <bits/stdc++.h>
using namespace std;
#define mp make_pair
#define int long long
typedef long long ll;
const int N=1e5+50;
struct Tree{
int to,val;
};
int root=0;
vector<Tree>v[N];
map<int,int>s;
pair<int,int>dig[N<<2];
ll ans=0,p[N];
int dep[N];
int n,m,num=0;
inline int exgcd(int a,int b,int& x,int& y)
{
if(!b) { x=1,y=0; return a; }
int d=exgcd(b,a%b,x,y);
int z=x; x=y; y=z-a/b*y;
return d;
}
inline int inv(int a,int m)//a在mod m意义下的逆元
{
int x,y,d=exgcd(a,m,x,y);
return d==1?(x%m+m)%m:-1;
}
bool vis[N];int sz[N],f[N],sum;
void f_root(int u,int fa)
{
sz[u]=1,f[u]=0;
for(int i=0;i<(int)v[u].size();i++)
{
int son=v[u][i].to;
if(son==fa||vis[son]) continue;
f_root(son,u);sz[u]+=sz[son];
f[u]=max(f[u],sz[son]);
}f[u]=max(f[u],sum-f[u]);
if(f[root]>f[u]) root=u;
}
//把根节点放到下往上数更容易~
void init(int u,int fa,int p1,int p2,int Dep)//当前节点,父亲节点,下往上,上往下,深度.
{
if(Dep>=0) s[p1]++,dig[++num]=mp(p2,Dep);
for(int i=0;i<(int)v[u].size();i++)
{
int son=v[u][i].to;int w=v[u][i].val;
if(fa==son||vis[son]) continue;
int p3=(p[Dep+1]*w+p1)%m;
int p4=(p2*10+w)%m;
init(son,u,p3,p4,Dep+1);
}
}
int cal(int u,int d)
{
s.clear();
int res=0;num=0;
if(d) init(u,0,d%m,d%m,0);
else init(u,0,0,0,-1);
for(int i=1;i<=num;i++)
{
int temp=((-dig[i].first*inv(p[dig[i].second+1],m))%m+m)%m;
res+=s[temp];
//if(s.find(temp)!=s.end()) res+=s[temp];
if(!d) res+=(!dig[i].first);
}if(!d) res+=s[0];
return res;
}
void solve(int u)
{
ans+=cal(u,0);vis[u]=true;
for(int i=0;i<(int)v[u].size();i++)
{
int son=v[u][i].to;int Val=v[u][i].val;
if(vis[son]) continue;
ans-=cal(son,Val);
sum=sz[son];root=0;f[0]=n;f_root(son,u);
solve(root);
}
}
signed main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
int a,b,c;
scanf("%lld%lld%lld",&a,&b,&c);
a++,b++;
v[a].push_back({b,c});
v[b].push_back({a,c});
}p[0]=1;
for(int i=1;i<=n;i++) p[i]=p[i-1]*10%m;
f[0]=n,root=0,sum=n;f_root(1,0);
solve(root);
printf("%lld\n",ans);
}

京公网安备 11010502036488号