任意模数多项式乘法

NTT

因为任意模数多项式乘法的系数可能超过了 所以一般的 就做不了。所以对于 的做法是找到 个模数,最后使用 合并一下,但是这里的合并中的答案可能超过 的值域,所以要用快速乘实现。

FFT

任意模数多项式乘法用 实现的问题主要是因为精度问题,这里使用了 次变化的 实现,采用了 作为底,那么一个多项式,我们根据 。那么 。就拆分成了系数比较小的四个多项式。最后合并一下就好了。

代码

这里只给出了 的代码实现。

#include<bits/stdc++.h>
using namespace std;
const int N = 4e5 + 100,base = 1 << 15;
#define db long double 
#define ll long long
const db pi = acos(-1);
int read() {
    int x = 0,f = 0;char ch = getchar();
    while(!isdigit(ch)) {if(ch=='-')f=1;ch=getchar();}
    while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();}
    return f?-x:x;
}
struct comp{
    db x,y;
    comp(db a,db b) {x=a;y=b;}
    comp() {x=y=0;}
    comp operator + (const comp a) const {return comp(x+a.x,y+a.y);}
    comp operator - (const comp a) const {return comp(x-a.x,y-a.y);}
    comp operator * (const comp a) const {return comp(a.x*x-a.y*y,y*a.x+a.y*x);}
}A1[N],A2[N],B1[N],B2[N];
int limit = 1,L,r[N],n,m,p;
void fft(comp *a,int type) {
    for(int i = 0;i < limit;i++) if(r[i] < i) swap(a[r[i]],a[i]);
    for(int mid = 1;mid < limit;mid <<= 1) {
        comp wn = comp(cos(1.0 * pi / mid),type * sin(1.0 * pi / mid));
        for(int i = 0;i < limit;i += (mid << 1)) {
            comp w = comp(1,0);
            for(int j = 0;j < mid;j++,w = w * wn) {
                comp x = a[i + j],y = w * a[i + j + mid];
                a[i + j] = x + y;a[i + j + mid] = x - y;
            }
        }
    }
    if(type == -1) for(int i = 0;i < limit;i++) a[i].x = a[i].x / limit;
}
void Merge(comp *a,comp *b,int B,int *f) {
    static comp g[N];
    for(int i = 0;i < limit;i++) g[i] = a[i] * b[i];fft(g,-1);
    for(int i = 0;i < limit;i++) f[i] = (f[i] + 1ll * B * ((ll)floor(g[i].x + 0.5) % p) % p) % p;
}
void MTT(comp *a,comp *b,comp *c,comp *d,int *f) {
    fft(a,1);fft(b,1);fft(c,1);fft(d,1);
    Merge(a,c,base * 1ll * base % p,f);Merge(a,d,base % p,f);
    Merge(b,c,base % p,f);Merge(b,d,1,f);
}
int main() {
    n = read();m = read();p = read();
    for(int i = 0,x;i <= n;i++) {x = read();A1[i].x = x / base;B1[i].x = x % base;} 
    for(int i = 0,x;i <= m;i++) {x = read();A2[i].x = x / base;B2[i].x = x % base;}
    while(limit <= n + m) limit <<= 1,L++;
    for(int i = 0;i < limit;i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << L - 1);
    static int Ans[N];memset(Ans,0,sizeof(Ans));
    MTT(A1,B1,A2,B2,Ans);
    for(int i = 0;i <= n + m;i++) printf("%d ",Ans[i]);
    return 0 & printf("\n");
}