模式字符串
今天考了一道类似点分治的模板题,结果没做出来。
正解是点分治:对每一个分治中心,统计各子树中到中心的路径能匹配模式串(循环展开后)的前缀、后缀的结点个数,并按长度对 m 取模分类;统计答案时把能首尾相接成整数个模式串的前缀与后缀拼接起来就行了。

#include<algorithm>
#include<cstdio>
#include<iostream>
#define ULL unsigned long long
using namespace std;
// T: number of test cases (read in main).
// n, m: the two integers read at the start of every test case; n bounds the
//       per-node loops and the hash tables, m is the modulus for pattern indices.
// tot: number of edge records currently stored in e[].
// x, y: endpoints of the edge being read.
// num: component size that Groot compares subtree sizes against.
// root: node currently selected by Groot (index 0 acts as a sentinel).
int T, n, m, tot, x, y, num, root;
// Answer accumulated by Gdeep and printed once per test case.
long long ans;
// N: capacity of every array below; zhi: multiplier of the polynomial hash.
const int N = 1000010, zhi = 13121;
// head[v]: index in e[] of the most recently added edge leaving v (0 = none).
// siz[v], mx[v]: subtree size and largest adjacent piece, both filled in by Groot.
// vis[v]: set to 1 once v has been processed as a centre in solve(int).
// f[], g[]: per-subtree counters incremented by Gdeep on an h1 / h2 hash match.
// sf[], sg[]: sums of f[] / g[] over the subtrees already merged at the current centre.
int head[N], siz[N], mx[N], vis[N], f[N], g[N], sf[N], sg[N];
// ch: pattern string, stored 1-based (filled by scanf); val[v]: letter on node v.
char ch[N], val[N];
// h1[i], h2[i]: hashes of the first i characters of the cyclically repeated
//               pattern, taken forwards (h1) and backwards from ch[m] (h2).
// base[i]: zhi raised to the i-th power (unsigned overflow wraps).
ULL h1[N], h2[N], base[N];
// Edge record of the adjacency lists: `to` is the target node, `nt` is the
// index of the next edge leaving the same source node (0 ends the list).
struct bian {int to, nt;} e[N << 1];
// Appends a directed edge f -> t to the adjacency list of f.
//
// f: source node, t: target node.
// The new record is placed at index ++tot of e[] and becomes the new list
// head, so edges are later visited in reverse order of insertion.
// Fields are assigned individually instead of through a `(bian){...}`
// compound literal, which is a GNU extension and not standard C++.
inline void add(int f, int t) 
{
	++tot;
	e[tot].to = t;
	e[tot].nt = head[f];
	head[f] = tot;
}
// Reads the next non-negative decimal integer from stdin.
//
// Every character before the first digit is skipped, and the character that
// follows the last digit is consumed as well.
// Returns the parsed value, or 0 if end of input is reached before any digit.
inline int read() 
{
	// getchar() returns int; keeping it as int lets EOF be told apart from
	// real characters, so the skip loop terminates at end of input.
	int c = getchar();
	while (c != EOF && (c < '0' || '9' < c)) c = getchar();
	int res = 0;
	while ('0' <= c && c <= '9')
	{
		res = res * 10 + (c - '0');
		c = getchar();
	}
	return res;
}
// Reads the next uppercase ASCII letter ('A'..'Z') from stdin, skipping every
// other character.
// Returns the letter, or '\0' if end of input is reached first.
inline char readc() 
{
	// Held as int so that EOF is distinguishable from real characters and
	// the skip loop cannot spin forever at end of input.
	int c = getchar();
	while (c != EOF && (c < 'A' || 'Z' < c)) c = getchar();
	return c == EOF ? '\0' : static_cast<char>(c);
}
// Walks the not-yet-removed part of the tree below x and records hash matches.
//
// x:   node being visited;  fa: node we arrived from (not revisited).
// dep: number of nodes on the path from the walk's starting node to x.
// hs:  hash of the letters on that path excluding x; val[x] is appended here.
//
// When the path hash equals h1[dep] (resp. h2[dep]) the counter f (resp. g)
// for this length class is incremented, and ans grows by the complementary
// class already accumulated in sg (resp. sf).
// Returns the number of nodes on the longest downward chain starting at x.
int Gdeep(int x, int fa, int dep, ULL hs) 
{
	const int phase = (dep - 1) % m;   // position inside the repeated pattern, 0-based
	hs = hs * zhi + val[x];
	if (hs == h1[dep])
	{
		++f[phase + 1];
		ans += sg[m - phase];
	}
	if (hs == h2[dep])
	{
		++g[phase + 1];
		ans += sf[m - phase];
	}
	int height = 1;
	for (int i = head[x]; i; i = e[i].nt)
	{
		const int to = e[i].to;
		if (vis[to] || to == fa) continue;
		const int below = Gdeep(to, x, dep + 1, hs) + 1;
		if (below > height) height = below;
	}
	return height;
}
// Computes siz[] and mx[] for the unremoved component reachable from x and
// updates the global `root` to the node with the smallest mx[] seen so far.
//
// x: node being visited;  fa: node we arrived from (not revisited).
// mx[x] ends up as the larger of the biggest child subtree and num - siz[x].
// The caller is expected to have set `root` to a node with a large mx[]
// (index 0 is used that way) before the first call.
void Groot(int x, int fa) 
{
	siz[x] = 1;
	mx[x] = 0;
	for (int i = head[x]; i; i = e[i].nt)
	{
		const int to = e[i].to;
		if (to == fa || vis[to]) continue;
		Groot(to, x);
		siz[x] += siz[to];
		if (siz[to] > mx[x]) mx[x] = siz[to];
	}
	// Whatever is not below x counts as one more piece.
	const int rest = num - siz[x];
	if (rest > mx[x]) mx[x] = rest;
	if (mx[x] < mx[root]) root = x;
}
// Resets per-test-case state: the answer, the edge counter, and the
// adjacency-list heads and visited flags of nodes 1..n (n still holds the
// value of the test case that has just finished).
void YYCH() 
{
	tot = 0;
	ans = 0;
	for (int v = 1; v <= n; ++v)
	{
		vis[v] = 0;
		head[v] = 0;
	}
}
// Processes x as a decomposition centre: counts the matching paths that pass
// through x, then recurses into each remaining neighbouring component.
//
// Subtrees are handled one at a time; Gdeep pairs the current subtree with
// the totals sf[]/sg[] of the subtrees merged before it, after which the
// subtree's own counters f[]/g[] are folded into the totals and cleared.
void solve(int x) 
{
	vis[x] = 1;
	// The centre alone acts as a length-1 piece on either side.
	sf[1] = 1;
	sg[1] = 1;
	int widest = 0;
	for (int i = head[x]; i; i = e[i].nt)
	{
		const int child = e[i].to;
		if (vis[child]) continue;
		// Highest index Gdeep may have touched in f[]/g[], capped at m.
		int reach = Gdeep(child, x, 2, val[x]) + 1;
		if (reach > m) reach = m;
		if (reach > widest) widest = reach;
		for (int j = 1; j <= reach; ++j)
		{
			sf[j] += f[j];
			sg[j] += g[j];
			f[j] = 0;
			g[j] = 0;
		}
	}
	// Wipe the totals so the next centre starts from a clean slate.
	for (int j = 1; j <= widest; ++j)
	{
		sf[j] = 0;
		sg[j] = 0;
	}
	for (int i = head[x]; i; i = e[i].nt)
	{
		const int child = e[i].to;
		if (vis[child]) continue;
		num = siz[child];
		root = 0;
		Groot(child, 0);
		solve(root);
	}
}
void solve() 
{
	YYCH();
	n = read(); m = read();
	for (int i = 1; i <= n; ++i)val[i] = readc();
	for (int i = 1; i < n; ++i)
		x = read(), y = read(), add(x, y), add(y, x);
	scanf("%s", ch + 1); base[0] = 1;
	for (int i = 1; i <= n; ++i)
		base[i] = base[i - 1] * zhi,
		h1[i] = h1[i - 1] + ch[(i - 1) % m + 1] * base[i - 1],
		h2[i] = h2[i - 1] + ch[m - (i - 1) % m] * base[i - 1];
	root = 0; mx[0] = 1 << 30; Groot(1, 0); solve(root);
	printf("%lld\n", ans);
}
// Entry point: reads the number of test cases, runs solve() once for each,
// then closes the standard streams.
int main() 
{
	for (T = read(); T--; )
		solve();
	fclose(stdin);
	fclose(stdout);
	return 0;
}