D Mi Re Do Si La? So Fa!

题意:给定串 SS,询问其每个子串的完整循环节长度和。例如串 aaaa的完整循环节长度有 1,2,41,2,4 三种。S1×106|S| \leq 1\times 10^6

解法:显然每个子串 S[l,r]S[l,r] 都有长度为 rl+1r-l+1 的循环节。考虑循环节长度不为子串本身的循环节长度:

在这里插入图片描述

其中红色块代表长度为 kk 的循环节,那么如果串中存在子串有长度为 kk 的完整循环节,那么对于该串覆盖的任意的分段点(如上图中的 ik,(i+1)k,(i+2)kik,(i+1)k, (i+2)k),则都有这若干处分段点开始(图中仅三处,可以更长)的总 lcp\rm lcp 长度与总 lcs\rm lcs 的长度和大于等于 kk

枚举长度 kk 再枚举分段点的位置 ikik,设 XXikik(i+1)k(i+1)k 处的 lcp\rm lcp 长度,YYikik 处与 (i+1)k(i+1)klcs\rm lcs 长度,统计起始位置在 ((i1)k,ik]((i-1)k,ik] 的所有子串,则满足完整循环节长度为 kk 的串个数等于从 ikik 处向前选择非空的一段、从 ikik 处再向后选择长度不超过 YY 的一段,最终将两段拼接起来长度为 kk 的倍数的方案数,即求 i=1Xj=0Y[i+j0(modk)]\displaystyle \sum_{i=1}^X\sum_{j=0}^Y[i+j \equiv 0 \pmod k],分类讨论即可。

最终总时间复杂度为 O(n2n)\mathcal O(n \log^2n)

#include <bits/stdc++.h>
using namespace std;
const int N = 1'000'000;
int lg[N + 5];
class RMQ
{
    const int N = 20;
    vector<vector<int>> st;

public:
    void build(vector<int> &val)
    {
        st.resize(N + 1);
        int n = val.size() - 1;
        for (int i = 0; i <= N; i++)
            st[i].resize(n + 1);
        for (int i = 1; i <= n; i++)
            st[0][i] = val[i];
        for (int i = 1; i <= 20; i++)
            for (int j = 1; j + (1 << i) - 1 <= n; j++)
                st[i][j] = min(st[i - 1][j], st[i - 1][j + (1 << i - 1)]);
    }
    int query(int l, int r)
    {
        int X = lg[r - l + 1];
        return min(st[X][l], st[X][r - (1 << X) + 1]);
    }
};
class SA
{
    vector<int> rk, sa, cnt, height, oldrk, px, id;
    int n;
    bool cmp(int x, int y, int w)
    {
        return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w];
    }
    RMQ st;

public:
    void build(string s)
    {
        int n = s.length(), m = 300;
        this->n = n;
        oldrk.resize(max(m + 1, 2 * n + 1));
        sa.resize(max(m + 1, n + 1));
        rk.resize(max(m + 1, n + 1));
        cnt.resize(max(m + 1, n + 1));
        height.resize(max(m + 1, n + 1));
        px.resize(max(m + 1, n + 1));
        id.resize(max(m + 1, n + 1));
        s = " " + s;
        for (int i = 1; i <= n; ++i)
            ++cnt[rk[i] = s[i]];
        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; --i)
            sa[cnt[rk[i]]--] = i;
        for (int w = 1, p;; w <<= 1, m = p)
        {
            p = 0;
            for (int i = n; i > n - w; --i)
                id[++p] = i;
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    id[++p] = sa[i] - w;
            fill(cnt.begin(), cnt.end(), 0);
            for (int i = 1; i <= n; ++i)
                ++cnt[px[i] = rk[id[i]]];
            for (int i = 1; i <= m; ++i)
                cnt[i] += cnt[i - 1];
            for (int i = n; i >= 1; --i)
                sa[cnt[px[i]]--] = id[i];
            copy(rk.begin(), rk.end(), oldrk.begin());
            p = 0;
            for (int i = 1; i <= n; ++i)
                rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p;
            if (p == n)
            {
                for (int i = 1; i <= n; ++i)
                    sa[rk[i]] = i;
                break;
            }
        }
        for (int i = 1, k = 0; i <= n; ++i)
        {
            if (rk[i] == 0)
                continue;
            if (k)
                --k;
            while (s[i + k] == s[sa[rk[i] - 1] + k])
                ++k;
            height[rk[i]] = k;
        }
        st.build(height);
    }
    int lcp(int x, int y)
    {
        if (x <= 0 || x > n || y <= 0 || y > n)
            return 0;
        int l = rk[x], r = rk[y];
        if (l > r)
            swap(l, r);
        return st.query(l + 1, r);
    }
};
class LongestCommon
{
    SA ord, rev;
    int n;

public:
    LongestCommon(string s)
    {
        this->n = s.length();
        ord.build(s);
        reverse(s.begin(), s.end());
        rev.build(s);
    }
    int lcp(int x, int y)
    {
        if (x <= 0 || x > n || y <= 0 || y > n)
            return 0;
        else
            return ord.lcp(x, y);
    }
    int lcs(int x, int y)
    {
        if (x <= 0 || x > n || y <= 0 || y > n)
            return 0;
        else
            return rev.lcp(n - x + 1, n - y + 1);
    }
};
int main()
{
    cin.tie(0)->sync_with_stdio(0);
    cin.exceptions(cin.failbit);
    cin.tie(NULL);
    cout.tie(NULL);
    for (int i = 2; i <= N; i++)
        lg[i] = lg[i >> 1] + 1;
    int t;
    string s;
    cin >> t;
    while (t--)
    {
        cin >> s;
        LongestCommon solve(s);
        int n = s.length();
        long long ans = 0;
        for (int i = 1; i <= n; i++)
        {
            ans += 1ll * i * (n - i + 1);
            for (int j = i; j + i <= n; j += i)
            {
                int suf = min(solve.lcs(j, j + i), i), pre = solve.lcp(j + 1, j + i + 1);
                long long s = 1ll * suf * (pre / i);
                pre %= i;
                if (suf == i)
                {
                    s++;
                    suf--;
                }
                int l = max(1, i - pre);
                if (l <= suf)
                    s += suf - l + 1;
                ans += s * i;
            }
        }
        cout << ans << "\n";
    }
    return 0;
}