一个长度为 \(n\) 字符集大小为 \(k\) 的字符串,它的回文串的个数是 \(k^{\lceil \frac{n}{2} \rceil}\)
发现根据题目里给的操作二, 可以生成 最小循环节的长度 个满足条件的字符串,
用这个长度的字符串拼出来长为 \(n\) 的字符串必须是回文串
设这个长度为 \(l\),满足这个长度(能拼成长为 \(n\) 的回文串)的字符串个数是 \(f[l]\) ,那么对答案的贡献就是 \(l \times f[l]\).
\(f\) 容斥计算一下即可。
#include<algorithm>
#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
LL n, k, tot, ans;
const int mod = 1e9 + 7, N = 2005;
LL w[N], f[N];
LL ksm(LL a, LL b, LL mod)
{
LL res = 1;
for (; b; b >>= 1, a = a * a % mod)
if (b & 1)res = res * a % mod;
return res;
}
int main()
{
freopen("string.in", "r", stdin);
freopen("string.out", "w", stdout);
cin >> n >> k;
for (int i = 1; i * i <= n; ++i)
if (!(n % i))
{
w[++tot] = i;
if (i * i != n)w[++tot] = n / i;
}
sort(w + 1, w + 1 + tot);
for (int i = 1; i <= tot; ++i)
{
f[i] = ksm(k, (w[i] + 1) / 2, mod);
for (int j = 1; j < i; ++j)
if (!(w[i] % w[j]))f[i] = (f[i] - f[j] + mod) % mod;
if (w[i] & 1)ans = (ans + f[i] * w[i]) % mod;
else ans = (ans + f[i] * (w[i] / 2)) % mod;
}
cout << ans;
fclose(stdin); fclose(stdout);
return 0;
}