任意模数多项式乘法
NTT
因为任意模数多项式乘法的系数可能超过了 所以一般的
就做不了。所以对于
的做法是找到
个模数,最后使用
合并一下,但是这里的合并中的答案可能超过
的值域,所以要用快速乘实现。
FFT
任意模数多项式乘法用 实现的问题主要是因为精度问题,这里使用了
次变化的
实现,采用了
作为底,那么一个多项式,我们根据
。那么
。就拆分成了系数比较小的四个多项式。最后合并一下就好了。
代码
这里只给出了 的代码实现。
#include<bits/stdc++.h>
using namespace std;
const int N = 4e5 + 100,base = 1 << 15;
#define db long double
#define ll long long
const db pi = acos(-1);
int read() {
int x = 0,f = 0;char ch = getchar();
while(!isdigit(ch)) {if(ch=='-')f=1;ch=getchar();}
while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();}
return f?-x:x;
}
struct comp{
db x,y;
comp(db a,db b) {x=a;y=b;}
comp() {x=y=0;}
comp operator + (const comp a) const {return comp(x+a.x,y+a.y);}
comp operator - (const comp a) const {return comp(x-a.x,y-a.y);}
comp operator * (const comp a) const {return comp(a.x*x-a.y*y,y*a.x+a.y*x);}
}A1[N],A2[N],B1[N],B2[N];
int limit = 1,L,r[N],n,m,p;
void fft(comp *a,int type) {
for(int i = 0;i < limit;i++) if(r[i] < i) swap(a[r[i]],a[i]);
for(int mid = 1;mid < limit;mid <<= 1) {
comp wn = comp(cos(1.0 * pi / mid),type * sin(1.0 * pi / mid));
for(int i = 0;i < limit;i += (mid << 1)) {
comp w = comp(1,0);
for(int j = 0;j < mid;j++,w = w * wn) {
comp x = a[i + j],y = w * a[i + j + mid];
a[i + j] = x + y;a[i + j + mid] = x - y;
}
}
}
if(type == -1) for(int i = 0;i < limit;i++) a[i].x = a[i].x / limit;
}
void Merge(comp *a,comp *b,int B,int *f) {
static comp g[N];
for(int i = 0;i < limit;i++) g[i] = a[i] * b[i];fft(g,-1);
for(int i = 0;i < limit;i++) f[i] = (f[i] + 1ll * B * ((ll)floor(g[i].x + 0.5) % p) % p) % p;
}
void MTT(comp *a,comp *b,comp *c,comp *d,int *f) {
fft(a,1);fft(b,1);fft(c,1);fft(d,1);
Merge(a,c,base * 1ll * base % p,f);Merge(a,d,base % p,f);
Merge(b,c,base % p,f);Merge(b,d,1,f);
}
int main() {
n = read();m = read();p = read();
for(int i = 0,x;i <= n;i++) {x = read();A1[i].x = x / base;B1[i].x = x % base;}
for(int i = 0,x;i <= m;i++) {x = read();A2[i].x = x / base;B2[i].x = x % base;}
while(limit <= n + m) limit <<= 1,L++;
for(int i = 0;i < limit;i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << L - 1);
static int Ans[N];memset(Ans,0,sizeof(Ans));
MTT(A1,B1,A2,B2,Ans);
for(int i = 0;i <= n + m;i++) printf("%d ",Ans[i]);
return 0 & printf("\n");
}
京公网安备 11010502036488号