模式字符串
今天考试遇到一道和点分治模板题很像的题目，结果没有做出来。
正解：对每个分治中心，统计子树中有多少点到中心的路径能读成模式串（循环意义下）的前缀，以及有多少条从中心出发的路径能读成模式串的后缀，并按中心在模式串中对应的位置分类；统计答案时，把前缀路径和后缀路径在中心处拼接起来，使总长度恰为模式串长度的整数倍即可。
#include<iostream>
#include<cstdio>
// Hash type: unsigned arithmetic wraps modulo 2^64, which the rolling hashes rely on
// instead of an explicit modulus. A typedef (not a macro) so the name respects scope.
typedef unsigned long long ULL;
using namespace std;

const int N = 1000010;  // array capacity; presumably the bound on n plus slack -- TODO confirm
const int zhi = 13121;  // base of the polynomial hashes

int T;          // number of test cases still to process
int n, m;       // node count and pattern length of the current case
int tot;        // number of directed edges stored in e
int x, y;       // scratch: endpoints of the edge being read
int num;        // size of the component Groot is searching for a centroid
int root;       // best centroid candidate found so far (0 = sentinel)
long long ans;  // number of qualifying paths in the current case

int head[N];        // first edge index of each node's adjacency list (0 = none)
int siz[N];         // subtree sizes from the latest Groot traversal
int mx[N];          // largest piece that would remain if the node were removed
int vis[N];         // 1 once the node has been used as a centroid
int f[N], g[N];     // current subtree's prefix / suffix matches, by pattern index
int sf[N], sg[N];   // the same counts summed over earlier subtrees of the centroid
char ch[N];         // pattern, 1-indexed
char val[N];        // node letters, 1-indexed
ULL h1[N], h2[N];   // hash of the first i letters of the pattern repeated forwards / backwards
ULL base[N];        // base[i] = zhi^i (mod 2^64)

// Adjacency-list edge: target node and the next edge leaving the same source node.
struct bian {int to, nt;} e[N << 1];
// Append the directed edge f -> t to f's adjacency list (new edge becomes the head).
// Members are assigned directly: the original `(bian){...}` compound literal is a
// C99 construct that C++ only accepts as a compiler extension.
inline void add(int f, int t)
{
    ++tot;
    e[tot].to = t;
    e[tot].nt = head[f];
    head[f] = tot;
}
// Read the next non-negative decimal integer from stdin, skipping any non-digit
// characters before it. Returns 0 if end of input is hit before a digit.
// c is an int so EOF stays distinguishable from real bytes; checking it keeps
// truncated or empty input from making the skip loop spin forever.
inline int read()
{
    int res = 0;
    int c = getchar();
    while (c != EOF && (c < '0' || '9' < c)) c = getchar();
    while ('0' <= c && c <= '9')
    {
        res = res * 10 + (c - '0');
        c = getchar();
    }
    return res;
}
// Read the next uppercase letter 'A'..'Z' from stdin, skipping every other byte.
// Returns '\0' if end of input is reached first; without the EOF check the skip
// loop would never terminate on truncated input.
inline char readc()
{
    int c = getchar();
    while (c != EOF && (c < 'A' || 'Z' < c)) c = getchar();
    return c == EOF ? '\0' : static_cast<char>(c);
}
// Walk the current component below a centroid, visiting x (child of fa) and
// everything under it that has not been removed (vis == 0).
//   dep: number of nodes on the path centroid..x; the centroid itself is depth 1
//        (solve() starts the walk at depth 2 with hs = the centroid's letter).
//   hs:  hash of the path centroid..fa; appending val[x] below gives the
//        centroid's letter the highest power of zhi and x's letter power 0.
// h1[dep] hashes the pattern repeated forwards (ch[1], ch[2], ...) with its
// first letter at power 0, so hs == h1[dep] means the path read from x up to
// the centroid spells ch[1], ch[2], ... cyclically, with the centroid on
// letter ch[(dep - 1) % m + 1]. h2 does the same for the pattern read
// backwards (ch[m], ch[m - 1], ...): read from the centroid down to x, the
// path ends with ch[m] at x, and the centroid is on ch[m - (dep - 1) % m].
// Each match is counted in f / g under index (dep - 1) % m + 1 and immediately
// paired with the complementary matches stored in sg / sf (earlier subtrees of
// this centroid), whose lengths complete the concatenation to a whole number
// of copies of the pattern; the number of such pairs is added to ans.
// Hash equality is on wrapping 64-bit unsigned values, so a match is
// probabilistic, not exact.
// Returns the number of nodes on the longest downward chain from x inside the
// component (1 for a leaf).
// NOTE(review): recursion depth equals that chain length; on very deep trees
// (N allows ~10^6 nodes) this may need a large stack -- confirm judge limits.
int Gdeep(int x, int fa, int dep, ULL hs)
{
hs = hs * zhi + val[x];
// Path x -> centroid is a cyclic prefix of the pattern: pair it with suffix paths.
if (hs == h1[dep])++f[(dep - 1) % m + 1], ans += sg[m - (dep - 1) % m];
// Path centroid -> x is a cyclic suffix of the pattern: pair it with prefix paths.
if (hs == h2[dep])++g[(dep - 1) % m + 1], ans += sf[m - (dep - 1) % m];
int tmp = 1;
for (int i = head[x]; i; i = e[i].nt)
if (!vis[e[i].to] && e[i].to != fa)tmp = max(tmp, Gdeep(e[i].to, x, dep + 1, hs) + 1);
return tmp;
}
// Centroid search over the component containing x (nodes with vis set count as
// removed). Stores in siz[] the subtree sizes of this rooted traversal and in
// mx[] the largest piece that would remain if the node were deleted, taking num
// as the component size. root moves to any node whose mx beats mx[root]; callers
// reset root to 0 beforehand, and mx[0] is primed with a huge value.
void Groot(int x, int fa)
{
    int total = 1;
    int heaviest = 0;
    for (int i = head[x]; i != 0; i = e[i].nt)
    {
        const int to = e[i].to;
        if (to == fa || vis[to])
            continue;
        Groot(to, x);
        total += siz[to];
        if (siz[to] > heaviest)
            heaviest = siz[to];
    }
    siz[x] = total;
    // The piece "above" x is the part of the component outside x's subtree.
    mx[x] = max(heaviest, num - total);
    if (mx[x] < mx[root])
        root = x;
}
// Reset per-case state before the next test case is read. It runs while n still
// holds the previous case's node count; only slots 1..n of head/vis can have been
// written since the last reset (higher slots are still zero), so clearing those
// is sufficient.
void YYCH()
{
    tot = 0;
    ans = 0;
    for (int i = n; i > 0; --i)
    {
        head[i] = 0;
        vis[i] = 0;
    }
}
// Count every qualifying path that passes through the centroid x, then remove x
// and recurse into each remaining piece of its component.
void solve(int x)
{
    vis[x] = 1;
    // x on its own acts as a one-node path at index 1 on both sides. Gdeep reads
    // index 1 only after a hash match that already forces x's letter to be the
    // one required there.
    sf[1] = 1;
    sg[1] = 1;
    int used = 0;  // highest sf/sg index that may have been written
    for (int i = head[x]; i; i = e[i].nt)
    {
        const int to = e[i].to;
        if (vis[to])
            continue;
        // Gdeep pairs this subtree's matches with sf/sg (earlier subtrees and x)
        // and leaves its own counts in f/g, at indices up to min(m, depth).
        int reach = Gdeep(to, x, 2, val[x]) + 1;
        if (reach > m)
            reach = m;
        if (reach > used)
            used = reach;
        // Merge this subtree into the running totals and reset f/g for the next one.
        for (int j = 1; j <= reach; ++j)
        {
            sf[j] += f[j];
            sg[j] += g[j];
            f[j] = 0;
            g[j] = 0;
        }
    }
    for (int j = 1; j <= used; ++j)
        sf[j] = sg[j] = 0;
    // siz[] still holds the sizes from the Groot pass that selected x: exact for
    // pieces below x in that traversal, approximate for the piece holding its
    // traversal parent.
    for (int i = head[x]; i; i = e[i].nt)
    {
        const int to = e[i].to;
        if (vis[to])
            continue;
        num = siz[to];
        root = 0;
        Groot(to, 0);
        solve(root);
    }
}
void solve()
{
YYCH();
n = read(); m = read();
for (int i = 1; i <= n; ++i)val[i] = readc();
for (int i = 1; i < n; ++i)
x = read(), y = read(), add(x, y), add(y, x);
scanf("%s", ch + 1); base[0] = 1;
for (int i = 1; i <= n; ++i)
base[i] = base[i - 1] * zhi,
h1[i] = h1[i - 1] + ch[(i - 1) % m + 1] * base[i - 1],
h2[i] = h2[i - 1] + ch[m - (i - 1) % m] * base[i - 1];
root = 0; mx[0] = 1 << 30; Groot(1, 0); solve(root);
printf("%lld\n", ans);
}
// Entry point: read the number of test cases and solve each one in turn.
int main()
{
    for (T = read(); T--; )
        solve();
    // Presumably left over from file-redirected judging; fclose also flushes output.
    fclose(stdin);
    fclose(stdout);
    return 0;
}