就是:
多项式求逆加任意模数ntt。
通常处理任意模数ntt有两种方式,这里选择其中一种。
吸氧过的。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#define ll long long
#define llu unsigned ll
using namespace std;
const int mod=1e9+7;
const int p=mod;
const int pm=32768;
const int maxn=4e5+100;
const double pi=acos(-1.0);
struct Complex
{
double x,y;
Complex(double xx=0.0,double yy=0.0)
{
x=xx,y=yy;
}
Complex operator - (const Complex &b) const
{
return Complex(x-b.x,y-b.y);
}
Complex operator + (const Complex &b) const
{
return Complex(x+b.x,y+b.y);
}
Complex operator * (const Complex &b) const
{
return Complex(x*b.x-y*b.y,x*b.y+y*b.x);
}
};
int fi[maxn],ans1[maxn],ans2[maxn];
int n;
Complex a1[maxn],b1[maxn],a2[maxn],b2[maxn],ww[maxn],ta[maxn];
int b[maxn],invb[maxn];
void fft(Complex *x,int len,int f)
{
for(int i=0;i<len;i++)
if(i<fi[i]) swap(x[i],x[fi[i]]);
for(int i=1;i<len;i<<=1)
{
for(int r=i<<1,j=0;j<len;j+=r)
{
for(int k=0;k<i;k++)
{
Complex w=ww[len/i*k];
w.y*=f;
Complex xx=x[j+k],yy=w*x[j+i+k];
x[j+k]=xx+yy;
x[j+i+k]=xx-yy;
}
}
}
if(f==-1)
for(int i=0;i<len;i++)
x[i].x/=len;
}
void get(Complex *x,Complex *y,int len,int pm,int *ans)
{
for(int i=0;i<len;i++) ta[i]=x[i]*y[i];
fft(ta,len,-1);
for(int i=0;i<len;i++)
ans[i]=(ans[i]+(ll)(ta[i].x+0.5)%p*1ll*pm)%p;
}
void mtt(int *a,int *b,int n,int *ans)
{
int len=1,cnt=0;
while(len<=((n-1)<<1)) len<<=1,cnt++;
for(int i=0;i<len;i++)
{
fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
ww[i]=Complex(cos(pi/len*i),sin(pi/len*i));
a1[i]=b1[i]=a2[i]=b2[i]=Complex(0.0,0.0);
ans[i]=0;
}
for(int i=0;i<n;i++)
{
a1[i].x=a[i]/pm,b1[i].x=a[i]%pm;
a2[i].x=b[i]/pm,b2[i].x=b[i]%pm;
}
fft(a1,len,1);
fft(b1,len,1);
fft(a2,len,1);
fft(b2,len,1);
get(a1,a2,len,pm*pm%p,ans);
get(a1,b2,len,pm%p,ans);
get(a2,b1,len,pm%p,ans);
get(b1,b2,len,1,ans);
}
int mypow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=1ll*ans*a%mod;
a=1ll*a*a%mod;
b>>=1;
}
return ans;
}
void inv(int n,int *a,int *b)
{
if(n==1)
{
b[0]=mypow(a[0],p-2);
return ;
}
inv((n+1)>>1,a,b);
mtt(a,b,n,ans1);
mtt(ans1,b,n,ans2);
for(int i=0;i<n;i++)
b[i]=(2ll*b[i]-ans2[i]+mod)%mod;
}
int main(void)
{
int n;
scanf("%d",&n);
for(int i=0;i<n;i++)
scanf("%d",&b[i]);
inv(n,b,invb);
for(int i=0;i<n;i++)
printf("%d ",invb[i]);
putchar('\n');
return 0;
}
终于解决了一个困惑我的问题。
见注释。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#define ll long long
#define llu unsigned ll
using namespace std;
const int mod=1e9+7;
const int p=mod;
const int pm=32768;
const int maxn=4e5+100;
const double pi=acos(-1.0);
struct Complex
{
double x,y;
Complex(double xx=0.0,double yy=0.0)
{
x=xx,y=yy;
}
Complex operator - (const Complex &b) const
{
return Complex(x-b.x,y-b.y);
}
Complex operator + (const Complex &b) const
{
return Complex(x+b.x,y+b.y);
}
Complex operator * (const Complex &b) const
{
return Complex(x*b.x-y*b.y,x*b.y+y*b.x);
}
};
int fi[maxn],ans1[maxn],ans2[maxn];
int n;
Complex a1[maxn],b1[maxn],a2[maxn],b2[maxn],ww[maxn],ta[maxn];
int b[maxn],invb[maxn];
void fft(Complex *x,int len,int f)
{
for(int i=0;i<len;i++)
if(i<fi[i]) swap(x[i],x[fi[i]]);
for(int i=1;i<len;i<<=1)
{
for(int r=i<<1,j=0;j<len;j+=r)
{
for(int k=0;k<i;k++)
{
Complex w=ww[len/i*k];
w.y*=f;
Complex xx=x[j+k],yy=w*x[j+i+k];
x[j+k]=xx+yy;
x[j+i+k]=xx-yy;
}
}
}
if(f==-1)
for(int i=0;i<len;i++)
x[i].x/=len;
}
void get(Complex *x,Complex *y,int len,int pm,int *ans)
{
for(int i=0;i<len;i++) ta[i]=x[i]*y[i];
fft(ta,len,-1);
for(int i=0;i<len;i++)
ans[i]=(ans[i]+(ll)(ta[i].x+0.5)%p*1ll*pm)%p;
}
void mtt(int *a,int *b,int n,int *ans)
{
int len=1,cnt=0;
while(len<=((n-1)<<1)) len<<=1,cnt++;
for(int i=0;i<len;i++)
{
fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
ww[i]=Complex(cos(pi/len*i),sin(pi/len*i));
a1[i]=b1[i]=a2[i]=b2[i]=Complex(0.0,0.0);
ans[i]=0;
}
for(int i=0;i<n;i++)
{
a1[i].x=a[i]/pm,b1[i].x=a[i]%pm;
a2[i].x=b[i]/pm,b2[i].x=b[i]%pm;
}
fft(a1,len,1);
fft(b1,len,1);
fft(a2,len,1);
fft(b2,len,1);
get(a1,a2,len,pm*pm%p,ans);
get(a1,b2,len,pm%p,ans);
get(a2,b1,len,pm%p,ans);
get(b1,b2,len,1,ans);
}
int mypow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=1ll*ans*a%mod;
a=1ll*a*a%mod;
b>>=1;
}
return ans;
}
void inv(int n,int *a,int *b)
{
if(n==1)
{
b[0]=mypow(a[0],p-2);
return ;
}
inv((n+1)>>1,a,b);
mtt(a,b,n,ans1);
//这里应该这样写,注意观察原式
//注意比较与对998244353取模时,用ntt做法的不同。
for(int i=0;i<n;i++)
ans1[i]=p-ans1[i];
ans1[0]=(ans1[0]+2)%p;
mtt(ans1,b,n,ans2);
for(int i=0;i<n;i++)
b[i]=ans2[i];
}
int main(void)
{
int n;
scanf("%d",&n);
for(int i=0;i<n;i++)
scanf("%d",&b[i]);
inv(n,b,invb);
for(int i=0;i<n;i++)
printf("%d ",invb[i]);
putchar('\n');
return 0;
}