虽然这题暴力即可,但这题一看就有组合DP做法,而且系数固定,显然可以用多项式科技优化,遂让AI写了一下此做法。
这是一个基于容斥原理结合生成函数的组合计数问题,使用NTT进行优化。
算法思路
- 容斥原理转化:我们要计算没有相邻字符相等的排列数。直接计算比较困难,我们使用容斥原理。对于每种字符
,假设它的出现次数为
。我们可以枚举该字符中有
对相邻字符是相等的(即“坏”连接)。如果选定了
对相邻相等,相当于将这
个字符分成了
个“块”。
- 生成函数构建:对于每种字符
,设我们要分成
个块(即选了
个坏连接)。将
个物品分成
个有序非空块的方案数是插板法
。在最终的排列中,我们有总共
个块。这些块的排列方案数为
。但是,同一种字符的
个块在上述插板法中是有序的(第1块、第2块...),而在混合排列时,为了还原回原字符串的顺序,这
个块的相对顺序必须保持不变。因此,对于每种字符,我们需要除以
。综合起来,对于字符
,如果分成
块,其对答案的贡献项(包含容斥系数 )的生成函数项为:
。
- 注意:容斥系数是
。坏连接数
。我们可以将
提出来,或者直接在多项式项中包含
。 实际上,若令多项式项为
,则最终结果需乘以
。因此,对于每种字符 ,构造多项式:
- 多项式乘法:我们需要计算所有字符对应多项式的乘积 。由于字符集大小
是常数,我们可以使用分治法(类似归并排序或哈夫曼树的思路)两两合并多项式。
- 计算答案:设
。最终答案为:
时间复杂度
- 设字符串长度为
。
- 构造所有多项式的时间为
。
- 多项式乘法:使用分治策略合并
个多项式。递归深度为
。每一层的多项式次数之和约为
,NTT的时间复杂度为
。
- 因此,总时间复杂度为
。
C++ 代码
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;
// 模数与原根
const int MOD = 998244353;
const int G = 3;
// 快速幂
long long power(long long base, long long exp) {
long long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
// 模逆元
long long modInverse(long long n) {
return power(n, MOD - 2);
}
// NTT (快速数论变换)
void ntt(vector<long long>& a, bool invert) {
int n = a.size();
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j ^= bit;
if (i < j) swap(a[i], a[j]);
}
for (int len = 2; len <= n; len <<= 1) {
long long wlen = power(G, (MOD - 1) / len);
if (invert) wlen = modInverse(wlen);
for (int i = 0; i < n; i += len) {
long long w = 1;
for (int j = 0; j < len / 2; j++) {
long long u = a[i + j], v = (a[i + j + len / 2] * w) % MOD;
a[i + j] = (u + v < MOD ? u + v : u + v - MOD);
a[i + j + len / 2] = (u - v >= 0 ? u - v : u - v + MOD);
w = (w * wlen) % MOD;
}
}
}
if (invert) {
long long n_inv = modInverse(n);
for (long long& x : a)
x = (x * n_inv) % MOD;
}
}
// 多项式乘法
vector<long long> multiply(const vector<long long>& a, const vector<long long>& b) {
if (a.empty() || b.empty()) return {};
int sz = a.size() + b.size() - 1;
int n = 1;
while (n < sz) n <<= 1;
vector<long long> fa(a.begin(), a.end());
vector<long long> fb(b.begin(), b.end());
fa.resize(n);
fb.resize(n);
ntt(fa, false);
ntt(fb, false);
for (int i = 0; i < n; i++)
fa[i] = (fa[i] * fb[i]) % MOD;
ntt(fa, true);
fa.resize(sz);
return fa;
}
// 组合数预处理
vector<long long> fact, invFact;
void prepareCombinatorics(int n) {
fact.resize(n + 1);
invFact.resize(n + 1);
fact[0] = 1;
for (int i = 1; i <= n; i++) fact[i] = (fact[i - 1] * i) % MOD;
invFact[n] = modInverse(fact[n]);
for (int i = n - 1; i >= 0; i--) invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
}
long long nCr(int n, int r) {
if (r < 0 || r > n) return 0;
return fact[n] * invFact[r] % MOD * invFact[n - r] % MOD;
}
// 分治合并多项式
// 时间复杂度: O(N log N log 26),其中 N 为字符串长度
vector<long long> solve(int l, int r, const vector<vector<long long>>& polys) {
if (l == r) return polys[l];
int mid = (l + r) / 2;
return multiply(solve(l, mid, polys), solve(mid + 1, r, polys));
}
int main() {
// 优化输入输出
ios_base::sync_with_stdio(false);
cin.tie(NULL);
string s;
if (!(cin >> s)) return 0;
int n = s.length();
if (n == 0) {
cout << 0 << endl;
return 0;
}
// 统计字符频率
int counts[26] = {0};
for (char c : s) counts[c - 'a']++;
prepareCombinatorics(n);
vector<vector<long long>> polys;
for (int i = 0; i < 26; i++) {
if (counts[i] == 0) continue;
int c = counts[i];
// 构造多项式 A_i(x)
// A_i(x) = sum_{j=1}^{c} (-1)^j / j! * C(c-1, j-1) * x^j
// 大小为 c+1,下标从 0 到 c
vector<long long> poly(c + 1, 0);
for (int j = 1; j <= c; j++) {
long long term = invFact[j]; // 1/j!
term = (term * nCr(c - 1, j - 1)) % MOD;
if (j % 2 == 1) {
// (-1)^j
term = (MOD - term) % MOD;
}
poly[j] = term;
}
polys.push_back(poly);
}
if (polys.empty()) {
cout << 0 << endl;
return 0;
}
// 分治乘法合并所有多项式
vector<long long> P = solve(0, polys.size() - 1, polys);
// 计算最终答案
// Ans = (-1)^n * sum_{k=0}^n k! * coeff[k]
long long ans = 0;
for (int k = 0; k < P.size(); k++) {
if (P[k] == 0) continue;
long long term = (P[k] * fact[k]) % MOD;
ans = (ans + term) % MOD;
}
// 乘以 (-1)^n
if (n % 2 == 1) {
ans = (MOD - ans) % MOD;
}
cout << ans << endl;
return 0;
}

京公网安备 11010502036488号