一个长度为 \(n\) 字符集大小为 \(k\) 的字符串,它的回文串的个数是 \(k^{\lceil \frac{n}{2} \rceil}\)

发现根据题目里给的操作二, 可以生成 最小循环节的长度 个满足条件的字符串,

用这个长度的字符串拼出来长为 \(n\) 的字符串必须是回文串

设这个长度为 \(l\),满足这个长度(能拼成长为 \(n\) 的回文串)的字符串个数是 \(f[l]\) ,那么对答案的贡献就是 \(l \times f[l]\).

\(f\) 容斥计算一下即可。

#include<algorithm>
#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
LL n, k, tot, ans;
const int mod = 1e9 + 7, N = 2005;
LL w[N], f[N];
LL ksm(LL a, LL b, LL mod)
{
	LL res = 1;
	for (; b; b >>= 1, a = a * a % mod)
		if (b & 1)res = res * a % mod;
	return res;
}
int main()
{
	freopen("string.in", "r", stdin);
	freopen("string.out", "w", stdout);
	cin >> n >> k;
	for (int i = 1; i * i <= n; ++i)
		if (!(n % i))
		{
			w[++tot] = i;
			if (i * i != n)w[++tot] = n / i;
		}
	sort(w + 1, w + 1 + tot);
	for (int i = 1; i <= tot; ++i)
	{
		f[i] = ksm(k, (w[i] + 1) / 2, mod);
		for (int j = 1; j < i; ++j)
			if (!(w[i] % w[j]))f[i] = (f[i] - f[j] + mod) % mod;
		if (w[i] & 1)ans = (ans + f[i] * w[i]) % mod;
		else ans = (ans + f[i] * (w[i] / 2)) % mod;
	}
	cout << ans;
	fclose(stdin); fclose(stdout);
	return 0;
}