后缀数组、重复子串、回文子串

题意:

图片说明

分析:

刚拿到这题我是崩溃的,这可咋求啊?!AABB的,太吓人了!!

然后我仔细看了看,其实问题的关键在于如何判断一个连续的AA串!!
如果我们能够计算出哪里有AA串就好了。
比如:
a[i]记录了一索引i结尾的AA串有多少个
b[i]记录了一索引i开始的AA串有多少个
然后我们只需要遍历一下原字符串ans+=a[i]*b[i+1]就可以了!

但是,如何计算出哪里有AA串呢???
正当我一筹莫展的时候,突然我意识到:这是一个重复子串问题!!!废话
但是,提起重复子串你会想到什么??最长重复子串问题!!!!!
如果你还不知道什么是最长重复字串问题,我建议你先去学习最长重复字串问题。他是后缀数组的经典应用!

现在你已经知道了如何求解最长重复字串问题了!

然后你将发现本问题和最长重复子串问题极度地相似!
我们同样可以枚举距离i,然后隔着距离i枚举j与j+i
我们求解AA一个经过j,一个经过j+i

如此求出所有的AA组然后记录a与b。

具体如何操作呢?
利用后缀数组我们先求出j与i+j向右方向的最长公共前缀
然后我们再想办法求出j与i+j想左方向的最长公共前缀
这样我们就知道了[left,j,right]他与[left+i,j+i,right+i]相等
但是我们限制是距离i所以应该是
[max(left,j-i+1),j,min(right,j+i-1)]
这里是第一个A的活动范围
具体到A的开头的首位就更小了[max(left,j-i+1),min(right-i+1,j);
我们可以确定将第一个A开头的首位放在这个区间里,一定也是AA的首位

这部分有点难懂,是我讲解的不好。。。。。。

至于,我们如何求解j与j+i左方向的最长公共前缀??
还记得最长回文子串是怎么求的吗?求看看,不用会做,只需知道做法。
我们在这里应用到了同样的技巧,我们reverse一下再求一个SA
或者再中间添一个绝对不会撞车的数字,然后在后面添加reverse后的字符串就好了(稍微慢点)
如此,此题得解

代码如下:

#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
#define re register
const int max_n = 6e4 + 100;
int ranks[max_n], SA[max_n], height[max_n];
int wa[max_n], wb[max_n], wvarr[max_n], wsarr[max_n];
inline int cmp(int* r, int a, int b, int l) {
    return r[a] == r[b] && r[a + l] == r[b + l];
}
inline void get_sa(int* r, int* sa, int n, int m) {
    //r:原数组
    //sa:SA
    //n:原数组长度
    //m:原数组种类数,用于基数排序
    ++n;
    re int i, j, p, * x = wa, * y = wb, * t;
    for (i = 0; i < m; ++i) wsarr[i] = 0;
    for (i = 0; i < n; ++i) wsarr[x[i] = r[i]]++;
    for (i = 1; i < m; ++i) wsarr[i] += wsarr[i - 1];
    for (i = n - 1; i >= 0; --i) sa[--wsarr[x[i]]] = i;
    for (j = 1, p = 1; p < n; j <<= 1, m = p) {
        for (p = 0, i = n - j; i < n; ++i) y[p++] = i;
        for (i = 0; i < n; ++i) if (sa[i] >= j) y[p++] = sa[i] - j;
        for (i = 0; i < n; ++i) wvarr[i] = x[y[i]];
        for (i = 0; i < m; ++i) wsarr[i] = 0;
        for (i = 0; i < n; ++i) wsarr[wvarr[i]]++;
        for (i = 1; i < m; ++i) wsarr[i] += wsarr[i - 1];
        for (i = n - 1; i >= 0; --i) sa[--wsarr[wvarr[i]]] = y[i];
        for (t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; ++i)
            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
    }
}//求解高度数组,height[i]指排名
void get_height(int* r, int* sa, int n) {
    re int i, j, k = 0;
    for (i = 1; i <= n; ++i) ranks[sa[i]] = i;
    for (i = 0; i < n; height[ranks[i++]] = k)
        for (k ? k-- : 0, j = sa[ranks[i] - 1]; r[i + k] == r[j + k]; k++);
    return;
}
//后缀数组

//附加rmq_st 专门对后缀数组
int st[max_n][32];
void initSt(int a[], int n) {
    for (int i = 0;i <= n;++i)st[i][0] = height[i];
    int mxk = (int)log2(n + 1);
    for (int k = 1;k <= mxk;++k) {
        for (int i = 0;i <= n;++i) {
            if (i + (1 << k) - 1 > n)break;
            st[i][k] = min(st[i][k - 1], st[i + (1 << (k - 1))][k - 1]);
        }
    }
}
int que(int l, int r) {
    l = ranks[l];r = ranks[r];
    if (l > r)swap(l, r);
    ++l;
    int len = (int)log2(r - l + 1);
    return min(st[l][len], st[r - (1 << len) + 1][len]);
}

string s;
int a[max_n];
int b[max_n];
int c[max_n];
void init(int n) {
    fill(a, a + n + 3, 0);
    fill(b, b + n + 3, 0);
    fill(c, c + n + 3, 0);
    fill(ranks, ranks + n + 3, 0);
    fill(SA, SA + n + 3, 0);
    fill(height, height + n + 3, 0);
    fill(wa, wa + n + 3, 0);
    fill(wb, wb + n + 3, 0);
    fill(wsarr, wsarr + n + 3, 0);
    fill(wvarr, wvarr + n + 3, 0);
}
int main() {
    ios::sync_with_stdio(0);
    int T;cin >> T;
    while (T--) {
        cin >> s;
        int n = s.size();init(n * 2 + 1);
        for (re int i = 0;i < s.size();++i)a[i] = (s[i] - 'a' + 1);
        reverse(s.begin(), s.end());
        a[n] = 29;
        for (re int i = 0;i < s.size();++i)a[i + n + 1] = (s[i] - 'a' + 1);
        n = n * 2 + 1;
        get_sa(a, SA, n, 30);
        get_height(a, SA, n);
        initSt(a, n);
        for (re int i = 1;i < s.size();++i) {
            for (re int j = 0;j + i < s.size();j += i) {
                if (a[j] != a[i + j])continue;
                int right = que(j, j + i);
                int t1 = n - j - 1;int t2 = n - (j + i) - 1;
                int left = que(t1, t2);
                left = max(j - left + 1, j - i + 1);
                right = min(j, j + right - i);
                for (re int k = left;k <= right;++k) {
                    ++b[k];
                    ++c[k + i + i - 1];
                }
            }
        }ll ans = 0;
        for (re int i = 0;i < s.size();++i)
            ans += (ll)c[i] * b[i + 1];
        cout << ans << endl;
    }
}