思路:
不妨考虑生成函数,因为是求排列数,所以首先想到指数函数,保证操作完后保证)),设关于))的操作数))个,其中))个可使)),则我们只要保证对))的最后一次操作为变为,所以其生成函数为
即如果刚开始,则不对进行操作也是合理的.
接下来就是将个多项式乘起来,显然直接乘会,我们考虑利用进行优化,显然进行次同样不行,我们考虑分治的方法进行,具体结构可以联想.
关于分治NTT的复杂度分析
看完题解的第一反应,不是和差不多吗?
因为分治的结构类似于线段树,所以我们分治进行了层,可以证明的是,当前所有多项式最高项之和一定为,则改成的复杂度为
,可以证明其复杂度是 在 附近,所以总复杂度为
在求出最后的多项式后,对于第个项,我们都用其系数乘上,再求个就位结果,值得注意的是,对于第一项(即项数为 0),我们应该直接加上其系数(PS:但好像不存在这种特殊的数据)
代码:
#include <iostream> #include <cstdio> #include <algorithm> #include <queue> #include <stack> #include <bitset> #include <vector> #include <map> #include <string> #include <cstring> #define fir first #define sec second using namespace std; typedef long long ll; const int maxn = 1e5+7; const ll mod = 998244353; int n,m,k; int a[maxn]; int sum[maxn],ok[maxn]; ll fast(ll a,ll b) { ll sum = 1; while(b) { if(b&1) sum = sum*a%mod; b >>= 1; a = a*a %mod; } return sum; } int g = 3; //NTT ll qpower(ll x, ll y) { ll res = 1; while (y) { if (y & 1) (res *= x) %= mod; (x *= x) %= mod; y >>= 1; } return res; } ll inline inv(ll x) { return qpower(x, mod - 2); } int rev[maxn<<1]; void NTT(ll *arr, int size, int type) { rev[0] = 0; for (int i = 1; i < size; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (size >> 1) : 0); for (int i = 0; i < size; ++i) if (rev[i] > i) std::swap(arr[i], arr[rev[i]]); for (int len = 2; len <= size; len <<= 1) { ll wn = qpower(g, (mod - 1) / len); if (type == -1) wn = inv(wn); for (int i = 0; i < size; i += len) { ll w = 1; for (int j = 0; j < (len >> 1); ++j, w = w * wn % mod) { ll tmp1 = arr[i + j], tmp2 = arr[i + (len >> 1) + j] * w % mod; arr[i + j] = tmp1 + tmp2; arr[i + j + (len >> 1)] = tmp1 - tmp2; if (arr[i + j] >= mod) arr[i + j] -= mod; if (arr[i + j + (len >> 1)] < 0) arr[i + j + (len >> 1)] += mod; } } } if (type == -1) { ll t = inv(size); for (int i = 0; i < size; ++i) arr[i] = arr[i] * t % mod; } } int cnt; ll F[300][maxn<<1]; int bitsize[maxn]; queue<int> q; void del(int x) {q.push(x);} int getcnt() { if(q.size()) { int x = q.front(); q.pop(); return x; } else return cnt++; } int cdq(int l,int r,int& num) { if(l==r) { num = getcnt(); F[num][sum[l]] = ok[l]*inv(sum[l])%mod; F[num][0] = -F[num][sum[l]]+(a[l]==0? 1:0); return sum[l]; } int mid = (l+r) >>1,L,R,sz1,sz2; sz1 = cdq(l,mid,L); sz2 = cdq(mid+1,r,R); NTT(F[L],bitsize[sz1+sz2],1); NTT(F[R],bitsize[sz1+sz2],1); for(int i=0;i<=bitsize[sz1+sz2];i++) { F[L][i] = F[L][i] * F[R][i] %mod; F[R][i] = 0; } NTT(F[L],bitsize[sz1+sz2],-1); num = L; del(R); return sz1+sz2; } int main() { scanf("%d%d%d",&n,&m,&k); for(int bit = 1,i = 0;i<=m;i++) { while(bit<=i) bit <<= 1; bitsize[i] = bit; } for(int i=1;i<=n;i++) scanf("%d",&a[i]); for(int i=1;i<=m;i++) { int x,y; scanf("%d%d",&x,&y); sum[x]++; if(y == 0) ok[x]++; } int rt; cdq(1,n,rt); ll ans = 0; for(int i=0;i<=m;i++) { if(i == 0) ans += F[rt][i]; else { ans += F[rt][i]*fast(i,k)%mod; } ans %= mod; } printf("%lld\n", (ans+mod)%mod); return 0; }