题意

给一个字符串,让我们找每个长度的子串中,是super串的个数。(类似双倍回文)

分析

我们对原串建立一个PAM,这样我们可以统计每种回文串出现次数,在用hash判断是不是super串
最后在统计一下就ok

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int N = 3e5 + 5;
const ull base = 131;
const ull mod = 1e9 + 7;
struct PAM_node
{
    int vis[26];
    int fail, len, num;
};
ull ha[N], se[N];
ull get_hash(int l, int r)//神奇的hash操作,ull自动取模
{
    return ha[r] - ha[l - 1] * se[r - l + 1];
}
char str[N];
struct PAM
{
    PAM_node pn[N];
    int n, length, last, cnt, s[N], ans[N];
    bool flag[N];
    void init()
    {
        pn[0].len = 0;
        pn[1].len = -1;
        pn[0].fail = 1;
        pn[1].fail = 0;
        clearn(0);
        clearn(1);
        last = 0;
        cnt = 1;
        n = 0;
        memset(ans, 0, sizeof(ans));
        memset(flag, false, sizeof(flag));
    }
    void clearn(int x)//注意清空
    {
        memset(pn[x].vis, 0, sizeof(pn[x].vis));
        pn[x].num = 0;
    }
    int get_fail(int x)
    {
        while(s[n - pn[x].len - 1] != s[n])
            x = pn[x].fail;
        return x;
    }
    void insert()
    {
        int p = get_fail(last);
        if(!pn[p].vis[s[n]])
        {
            pn[++ cnt].len = pn[p].len + 2;
            int tmp = get_fail(pn[p].fail);
            pn[cnt].fail = pn[tmp].vis[s[n]];
            pn[p].vis[s[n]] = cnt;
            clearn(cnt);
            int need = (pn[cnt].len + 1) / 2;
            if(pn[cnt].len == 1 || get_hash(n - pn[cnt].len + 1, n - pn[cnt].len + need) == get_hash(n - need + 1, n))//判断是不是super串
                flag[cnt] = true;
        }
        last = pn[p].vis[s[n]];
        pn[last].num ++;
    }
    void count()//统计数量
    {
        for(int i = cnt; i >= 1; i --)
        {
            pn[pn[i].fail].num += pn[i].num;
            if(flag[i])
                ans[pn[i].len] += pn[i].num;
        }
    }
    void solve()
    {
        s[0] = 26;
        for(n = 1; n <= length; n ++)
        {
            s[n] = str[n] - 'a';
            insert();
        }
        count();
        for(int i = 1; i <= length; i ++)
            printf("%d%s", ans[i], (i != length ? " " : "\n"));
    }
}pam;
int main()
{
    se[0] = 1;
    for(int i = 1; i < N; i ++)
        se[i] = se[i - 1] * base;  
    while(cin >> (str + 1))
    {
        int len = strlen(str + 1);
        ha[0] = 1;
        for(int i = 1; i <= len; i ++)//字符串哈希
        	ha[i] = ha[i - 1] * base + str[i]; 
        pam.init();
        pam.length = len;
        pam.solve();
    }
    return 0;
}