显然要是没有限制那就全都是\(1\)对吧,
所以考虑这些\(0, 1\)到底限制了啥。
首先看到\(border\), 容易想到当年写\(kmp\)求最短回文子串的那题。因为\(kmp\)的\(next\)数组实际上就是最长\(border\)。我们也在那题积累一个结论, 就是对于原来的串一个长度为\(k\)的\(border\), 原串一定是由长度为\(n -k\)的子串循环而成的,但是最后一部分并不一定是一个完整的循环节。(实际上你也可以看成第一个不一定是一个完整的循环节, 这一部分画图很容易看出来)。那么如果有一组\(0,1\)的位置距离恰好是\(n-k\),那显然就是不行的, 因为按照要求, 它们应该相等。显然, 如果这个距离是\(n-k\)的因数也是不行的。
然后一个简单的暴力就出来了, 暴力枚举所有\(0, 1\), 求出它们的距离, 然后取并集, 对于这个集合里的长度暴力枚举倍数来判断是否合法, 复杂度\(O(n^2+n\ln n)\)。
然后我们考虑优化这个做法, 优化这个我只知道有数据结构和多项式两条路,感觉数据结构不太行那就搞多项式。我们要求的是距离, 就设\(F(x) = \sum_{i=0}^{n-1}x^if_i\), 其中\(f_i\)表示距离为\(i\)的\(0,1\)点对有多少对。 然后设\(A(x)=\sum_{i=0}^{n-1}[s[i] =='0']x^i, B(x)=\sum_{i=0}^{n-1}[s[i]=='1']x^{-i}\),考虑怎么解决这个负数幂次, 让\(B(x)=x^{n}B(\frac{1}{x})\),最后求答案的时候直接位置加上\(n\)就好了。
#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
#define R register
#define LL long long
const int N = 2e6 + 10;
const int P = 998244353;
const int G = 3;
const int Gi = 332748118;
inline int read() {
int x = 0, f = 1; char a = getchar();
for(; a > '9' || a < '0'; a = getchar()) if(a == '-') f = -1;
for(; a >= '0' && a <= '9'; a = getchar()) x = x * 10 + a - '0';
return x * f;
}
inline int qpow(int x, int k) {
int res = 1;
while(k) {
if(k & 1) res = 1LL * res * x % P;
x = 1LL * x * x % P; k >>= 1;
} return res;
}
int bit, rev[N], lim = 1;
inline void NTT(int *A, int type = 1) {
for(R int i = 0; i < lim; i ++) if(i > rev[i]) swap(A[i], A[rev[i]]);
for(R int dep = 1; dep < lim; dep <<= 1) {
int Wn = qpow(type == 1 ? G : Gi, (P - 1) / (dep << 1));
for(R int j = 0, len = dep << 1; j < lim; j += len) {
int w = 1;
for(R int k = 0; k < dep; k ++, w = 1LL * w * Wn % P) {
int x = A[j + k], y = 1LL * A[j + k + dep] * w % P;
A[j + k] = (x + y) % P;
A[j + k + dep] = (x - y + P) % P;
}
}
}
if(type == -1) {
int inv = qpow(lim, P - 2);
for(R int i = 0; i < lim; i ++) A[i] = 1LL * A[i] * inv % P;
}
}
int n;
char s[N];
int a[N], b[N];
int main() {
//freopen(".in", "r", stdin);
//freopen(".out", "w", stdout);
scanf("%s", s); n = strlen(s);
while(lim < (n << 1)) lim <<= 1, bit ++;
for(R int i = 0; i < lim; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
for(R int i = 0; i < n; i ++) a[i] = (s[i] == '0');
for(R int i = 0; i < n; i ++) b[i] = (s[n - i - 1] == '1');
NTT(a); NTT(b);
for(R int i = 0; i < lim; i ++) a[i] = 1LL * a[i] * b[i] % P;
NTT(a, -1);
LL ans = 1LL * n * n;
for(R int i = 1; i < n; i ++) {
int f = 1;
for(R int j = i; j < n; j += i) if(a[n - 1 - j] | a[n - 1 + j]) {f = 0; break; }
if(f) ans ^= 1LL * (n - i) * (n - i);
}
printf("%lld\n", ans);
return 0;
}