D Mi Re Do Si La? So Fa!
题意:给定串 ,询问其每个子串的完整循环节长度和。例如串 aaaa
的完整循环节长度有 三种。。
解法:显然每个子串 都有长度为 的循环节。考虑循环节长度不为子串本身的循环节长度:
其中红色块代表长度为 的循环节,那么如果串中存在子串有长度为 的完整循环节,那么对于该串覆盖的任意的分段点(如上图中的 ),则都有这若干处分段点开始(图中仅三处,可以更长)的总 长度与总 的长度和大于等于 。
枚举长度 再枚举分段点的位置 ,设 为 与 处的 长度, 为 处与 的 长度,统计起始位置在 的所有子串,则满足完整循环节长度为 的串个数等于从 处向前选择非空的一段、从 处再向后选择长度不超过 的一段,最终将两段拼接起来长度为 的倍数的方案数,即求 ,分类讨论即可。
最终总时间复杂度为 。
#include <bits/stdc++.h>
using namespace std;
const int N = 1'000'000;
int lg[N + 5];
class RMQ
{
const int N = 20;
vector<vector<int>> st;
public:
void build(vector<int> &val)
{
st.resize(N + 1);
int n = val.size() - 1;
for (int i = 0; i <= N; i++)
st[i].resize(n + 1);
for (int i = 1; i <= n; i++)
st[0][i] = val[i];
for (int i = 1; i <= 20; i++)
for (int j = 1; j + (1 << i) - 1 <= n; j++)
st[i][j] = min(st[i - 1][j], st[i - 1][j + (1 << i - 1)]);
}
int query(int l, int r)
{
int X = lg[r - l + 1];
return min(st[X][l], st[X][r - (1 << X) + 1]);
}
};
class SA
{
vector<int> rk, sa, cnt, height, oldrk, px, id;
int n;
bool cmp(int x, int y, int w)
{
return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w];
}
RMQ st;
public:
void build(string s)
{
int n = s.length(), m = 300;
this->n = n;
oldrk.resize(max(m + 1, 2 * n + 1));
sa.resize(max(m + 1, n + 1));
rk.resize(max(m + 1, n + 1));
cnt.resize(max(m + 1, n + 1));
height.resize(max(m + 1, n + 1));
px.resize(max(m + 1, n + 1));
id.resize(max(m + 1, n + 1));
s = " " + s;
for (int i = 1; i <= n; ++i)
++cnt[rk[i] = s[i]];
for (int i = 1; i <= m; ++i)
cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; --i)
sa[cnt[rk[i]]--] = i;
for (int w = 1, p;; w <<= 1, m = p)
{
p = 0;
for (int i = n; i > n - w; --i)
id[++p] = i;
for (int i = 1; i <= n; ++i)
if (sa[i] > w)
id[++p] = sa[i] - w;
fill(cnt.begin(), cnt.end(), 0);
for (int i = 1; i <= n; ++i)
++cnt[px[i] = rk[id[i]]];
for (int i = 1; i <= m; ++i)
cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; --i)
sa[cnt[px[i]]--] = id[i];
copy(rk.begin(), rk.end(), oldrk.begin());
p = 0;
for (int i = 1; i <= n; ++i)
rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p;
if (p == n)
{
for (int i = 1; i <= n; ++i)
sa[rk[i]] = i;
break;
}
}
for (int i = 1, k = 0; i <= n; ++i)
{
if (rk[i] == 0)
continue;
if (k)
--k;
while (s[i + k] == s[sa[rk[i] - 1] + k])
++k;
height[rk[i]] = k;
}
st.build(height);
}
int lcp(int x, int y)
{
if (x <= 0 || x > n || y <= 0 || y > n)
return 0;
int l = rk[x], r = rk[y];
if (l > r)
swap(l, r);
return st.query(l + 1, r);
}
};
class LongestCommon
{
SA ord, rev;
int n;
public:
LongestCommon(string s)
{
this->n = s.length();
ord.build(s);
reverse(s.begin(), s.end());
rev.build(s);
}
int lcp(int x, int y)
{
if (x <= 0 || x > n || y <= 0 || y > n)
return 0;
else
return ord.lcp(x, y);
}
int lcs(int x, int y)
{
if (x <= 0 || x > n || y <= 0 || y > n)
return 0;
else
return rev.lcp(n - x + 1, n - y + 1);
}
};
int main()
{
cin.tie(0)->sync_with_stdio(0);
cin.exceptions(cin.failbit);
cin.tie(NULL);
cout.tie(NULL);
for (int i = 2; i <= N; i++)
lg[i] = lg[i >> 1] + 1;
int t;
string s;
cin >> t;
while (t--)
{
cin >> s;
LongestCommon solve(s);
int n = s.length();
long long ans = 0;
for (int i = 1; i <= n; i++)
{
ans += 1ll * i * (n - i + 1);
for (int j = i; j + i <= n; j += i)
{
int suf = min(solve.lcs(j, j + i), i), pre = solve.lcp(j + 1, j + i + 1);
long long s = 1ll * suf * (pre / i);
pre %= i;
if (suf == i)
{
s++;
suf--;
}
int l = max(1, i - pre);
if (l <= suf)
s += suf - l + 1;
ans += s * i;
}
}
cout << ans << "\n";
}
return 0;
}