思路:
不妨考虑生成函数,因为是求排列数,所以首先想到指数函数,保证操作完后保证)),设关于
))的操作数
))个,其中
))个可使
)),则我们只要保证对
))的最后一次操作为变为
,所以其生成函数为
即如果刚开始,则不对
进行操作也是合理的.
接下来就是将个多项式乘起来,显然直接乘会
,我们考虑利用
进行优化,显然进行
次
同样不行,我们考虑分治的方法进行
,具体结构可以联想
.
关于分治NTT的复杂度分析
看完题解的第一反应,不是和
差不多吗?
因为分治的结构类似于线段树,所以我们分治进行了层,可以证明的是,当前所有多项式最高项之和一定为
,则改成的复杂度为
,可以证明其复杂度是 在
附近,所以总复杂度为
在求出最后的多项式后,对于第个项,我们都用其系数乘上
,再求个就位结果,值得注意的是,对于第一项(即项数为 0),我们应该直接加上其系数(PS:但好像不存在这种特殊的数据)
代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <stack>
#include <bitset>
#include <vector>
#include <map>
#include <string>
#include <cstring>
#define fir first
#define sec second
using namespace std;
typedef long long ll;
const int maxn = 1e5+7;
const ll mod = 998244353;
int n,m,k;
int a[maxn];
int sum[maxn],ok[maxn];
ll fast(ll a,ll b) {
ll sum = 1;
while(b) {
if(b&1) sum = sum*a%mod;
b >>= 1;
a = a*a %mod;
}
return sum;
}
int g = 3;
//NTT
ll qpower(ll x, ll y) {
ll res = 1;
while (y) {
if (y & 1) (res *= x) %= mod;
(x *= x) %= mod;
y >>= 1;
}
return res;
}
ll inline inv(ll x) { return qpower(x, mod - 2); }
int rev[maxn<<1];
void NTT(ll *arr, int size, int type) {
rev[0] = 0;
for (int i = 1; i < size; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (size >> 1) : 0);
for (int i = 0; i < size; ++i)
if (rev[i] > i) std::swap(arr[i], arr[rev[i]]);
for (int len = 2; len <= size; len <<= 1) {
ll wn = qpower(g, (mod - 1) / len);
if (type == -1) wn = inv(wn);
for (int i = 0; i < size; i += len) {
ll w = 1;
for (int j = 0; j < (len >> 1); ++j, w = w * wn % mod) {
ll tmp1 = arr[i + j], tmp2 = arr[i + (len >> 1) + j] * w % mod;
arr[i + j] = tmp1 + tmp2;
arr[i + j + (len >> 1)] = tmp1 - tmp2;
if (arr[i + j] >= mod) arr[i + j] -= mod;
if (arr[i + j + (len >> 1)] < 0) arr[i + j + (len >> 1)] += mod;
}
}
}
if (type == -1) {
ll t = inv(size);
for (int i = 0; i < size; ++i)
arr[i] = arr[i] * t % mod;
}
}
int cnt;
ll F[300][maxn<<1];
int bitsize[maxn];
queue<int> q;
void del(int x) {q.push(x);}
int getcnt() {
if(q.size()) {
int x = q.front();
q.pop();
return x;
}
else return cnt++;
}
int cdq(int l,int r,int& num) {
if(l==r) {
num = getcnt();
F[num][sum[l]] = ok[l]*inv(sum[l])%mod;
F[num][0] = -F[num][sum[l]]+(a[l]==0? 1:0);
return sum[l];
}
int mid = (l+r) >>1,L,R,sz1,sz2;
sz1 = cdq(l,mid,L);
sz2 = cdq(mid+1,r,R);
NTT(F[L],bitsize[sz1+sz2],1);
NTT(F[R],bitsize[sz1+sz2],1);
for(int i=0;i<=bitsize[sz1+sz2];i++) {
F[L][i] = F[L][i] * F[R][i] %mod;
F[R][i] = 0;
}
NTT(F[L],bitsize[sz1+sz2],-1);
num = L;
del(R);
return sz1+sz2;
}
int main() {
scanf("%d%d%d",&n,&m,&k);
for(int bit = 1,i = 0;i<=m;i++) {
while(bit<=i) bit <<= 1;
bitsize[i] = bit;
}
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<=m;i++) {
int x,y;
scanf("%d%d",&x,&y);
sum[x]++;
if(y == 0) ok[x]++;
}
int rt;
cdq(1,n,rt);
ll ans = 0;
for(int i=0;i<=m;i++) {
if(i == 0) ans += F[rt][i];
else {
ans += F[rt][i]*fast(i,k)%mod;
}
ans %= mod;
}
printf("%lld\n", (ans+mod)%mod);
return 0;
} 
京公网安备 11010502036488号