题意
一颗树,每个节点有一个字母,对于每一个结点,输出经过他的路径的个数,要求路径上经过的字母可以组成回文串
题解
做这个题之前,只知道边分治牛逼,有种边分治无敌的错觉…紧接着做这个题就被啪啪打脸
边分治根本做不了,只好点分治了
枚举一个结点,找出所有经过他的符合要求的路径,但是,难的地方在于,所有的路径上的点都要统计进去答案
为解决这个问题,可以采用只在路径末端打标记,然后从叶子到根向上传递累加答案的方式
怎么找路径呢?
首先找出从根出发的所有路径的状态
然后对每个子树依次统计
统计时,就是要计算,当前子树的路径走到根,再走到其他子树,有多少种合法路径
首先,要去除子树自身的影响,然后统计,最后恢复影响
最后再计算答案即可
代码
#include<bits/stdc++.h>
#define N 200010
#define INF 0x3f3f3f3f
#define eps 1e-5
#define pi 3.141592653589793
#define mod 998244353
#define P 1000000007
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<<x<<endl
#define mem(x,y) memset(x,0,sizeof(int)*(y+3))
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
typedef pair<int,int> pp;
int n,tn,sn,cnt,rt,mn,sz[N],q[N],del[N],w[N],ls[N];
int r[20];
LL ans[N],z[N];
vector<int> a[N];
void findrt(int x,int fa){
sz[x]=1; int t=0;
for(auto i:a[x]){
if (i==fa||del[i]) continue;
findrt(i,x);
sz[x]+=sz[i];
t=max(t,sz[i]);
}
t=max(t,tn-sz[x]);
if (t<mn){mn=t; rt=x; }
}
int tet[1<<20];
int k,m;
void pre(int x,int fa,int s){
q[cnt++]=s; ls[cnt-1]=x;
z[x]=0; tet[s]++;
if (__builtin_popcount(s^k)<=1) z[x]++,m++;
for(auto i:a[x]) if (i!=fa&&!del[i]) pre(i,x,s^r[w[i]]);
}
void getans(int x,int fa){
for(auto i:a[x]) if (i!=fa&&!del[i]){
getans(i,x);
z[x]+=z[i];
}
ans[x]+=z[x];
}
int d[N];
void cal(int x){
k=r[w[x]]; m=0; int u=0;
cnt=0; z[x]=0;
for(auto i:a[x])if (!del[i]) {
pre(i,x,r[w[i]]);
d[++u]=cnt;
}
u=0;
for(auto i:a[x]) if (!del[i]){
u++;
for(int j=d[u-1];j<d[u];j++) tet[q[j]]--;
for(int j=d[u-1];j<d[u];j++){
int v=ls[j],s=q[j];
z[v]+=tet[s^k];
for(int i=0;i<20;i++) z[v]+=tet[s^k^r[i]];
}
for(int j=d[u-1];j<d[u];j++) tet[q[j]]++;
}
for(int i=0;i<cnt;i++) tet[q[i]]--;
getans(x,-1);
ans[x]-=(z[x]-m)/2;
}
void solve(int x){
cal(x);del[x]=1;
int tm=tn;
for(auto i:a[x])if(!del[i]){
tn=sz[x]>sz[i]?sz[i]:tm-sz[x];
mn=INF;
findrt(i,-1);
solve(rt);
}
}
char ch[N];
int main(int argc, char const *argV[]){
for(int i=0;i<20;i++) r[i]=1<<i;
sc(n);
for(int i=1;i<n;i++){
int x,y; scc(x,y);
a[x].pb(y);a[y].pb(x);
}
scanf("%s",ch+1);
for(int i=1;i<=n;i++) w[i]=ch[i]-'a';
mn=INF; tn=n;
findrt(1,-1);
solve(rt);
for(int i=1;i<=n;i++) printf("%lld%c",ans[i]+1," \n"[i==n]);
return 0;
}