题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6599
题意:给出一个长度为N的字符串,要求输出一个长度为N的数组A, A[i]表示长度为i的good substring的数量
good substring 的定义是 该子串是回文串,且该子串的一半也是回文串。
题解:回文自动机找出所有的本质不相同的回文串,再用Manacher判断其一半是否也是回文。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+10;// n(空间复杂度o(n*ALP)),实际开n即可
const int ALP = 26;
char s[maxn], m[maxn * 2];
int p[maxn * 2];
int ans[maxn];
struct PAM // 每个节点代表一个回文串
{
int next[maxn][ALP]; // next指针,参照Trie树
int fail[maxn]; // fail失配后缀链接
int cnt[maxn]; // 此回文串出现个数
int num[maxn]; //表示以节点i表示的最长回文串的最右端点为回文串结尾的回文串个数
int len[maxn]; //回文串长度
int s[maxn]; // 存放添加的字符
int last; //指向上一个字符所在的节点,方便下一次add
int n; // 已添加字符个数
int p; // 节点个数
int endpos[maxn];//出现的最后位置
int newnode(int w) { // 初始化节点,w=长度
for (int i = 0; i < ALP; i++)
next[p][i] = 0;
cnt[p] = 0;
num[p] = 0;
len[p] = w;
return p++;
}
void init() {
p = 0;
newnode(0);
newnode(-1);
last = 0;
n = 0;
s[n] = -1; // 开头放一个字符集中没有的字符,减少特判
fail[0] = 1;
}
int get_fail(int x) // 和KMP一样,失配后找一个尽量最长的
{
while (s[n - len[x] - 1] != s[n]) x = fail[x];
return x;
}
void add(int c,int pos)
{
c -= 'a'; //注意修改从‘a’开始
s[++n] = c;
int cur = get_fail(last);
if (!next[cur][c])
{
int now = newnode(len[cur] + 2);
fail[now] = next[get_fail(fail[cur])][c];
next[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = next[cur][c];
cnt[last]++;
endpos[last] = pos;
}
void count() // 最后统计一遍每个节点出现个数
{
for (int i = p - 1; i >= 0; i--)
cnt[fail[i]] += cnt[i];
}
}run;
void Manacher()
{
int r = 0, id = 0;
for (int i = 0; m[i] != '\0'; i++)
{
if (i < r)
{
p[i] = min(p[2 * id - i], r - i);
}
while (m[i + p[i]] == m[i - p[i]])
{
p[i]++;
}
if (i + p[i] > r)
{
r = i + p[i];
id = i;
}
}
}
int main()
{
while (~scanf("%s", &s))
{
run.init();
int n = strlen(s);
//回文自动机
for (int i = 0; i < n; i++) run.add(s[i],i+1);
run.count();
//manacher
int j = 0;
m[j++] = '@';
for (int i = 0; i < n; i++)
{
p[j] = 0;
m[j++] = '#';
p[j] = 0;
m[j++] = s[i];
p[j] = 0;
}
m[j++] = '#';
m[j] = '\0';
Manacher();
for (int i = 2; i <= run.p - 1; i++)
{
if (p[((run.endpos[i] - (run.len[i]) / 2) * 2 + (run.endpos[i] +1- run.len[i]) * 2) / 2] - 1 >= (run.len[i] + 1) / 2)
ans[run.len[i]] += run.cnt[i];
}
for (int i = 1; i <= n; i++)
{
if (i == 1)
cout << ans[i];
else
cout << ' ' << ans[i];
ans[i] = 0;
}
cout << endl;
}
}