注意:这里多项式快速幂,多项式的幂,幂可以直接对系数的模mod取模。
即,m%=mod即可。
但是在处理的过程中 要计算 A0m,这里 A0m%mod = A0m%(mod-1)%mod
但是不知道为什么。。。我的代码跑的比别人的慢好多。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#define ll long long
#define llu unsigned ll
using namespace std;
const int maxn=8e5+100;
const int p=998244353;
const int mod=998244353;
const int g=3;
int fi[maxn];
int a[maxn],inva[maxn],da[maxn],lna[maxn],c[maxn],iinv[maxn];
int f[maxn];
int n;
ll m=0,mm=0;
void get_inv(int n)
{
iinv[1]=1;
for(int i=2;i<=n;i++)
iinv[i]=1ll*(mod-mod/i)*iinv[mod%i]%mod;
}
int mypow(int a,int b)
{
if(b<0) return mypow(mypow(a,p-2),-b);
int ans=1;
while(b)
{
if(b&1) ans=1ll*ans*a%p;
a=1ll*a*a%p;
b>>=1;
}
return ans%p;
}
int getlen(int n,int m)//n,m是次数
{
int len=1,cnt=0;
while(len<=n+m) len<<=1,cnt++;
for(int i=0;i<len;i++)
fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
return len;
}
void ntt(int *x,int len,int f)
{
for(int i=0;i<len;i++)
if(i<fi[i]) swap(x[i],x[fi[i]]);
for(int i=1;i<len;i<<=1)
{
int r=i<<1;
int wn=mypow(g,f*(p-1)/r);
for(int j=0;j<len;j+=r)
{
int w=1;
for(int k=0;k<i;k++)
{
int xx=x[j+k],yy=1ll*w*x[j+i+k]%p;
x[j+k]=(xx+yy)%p;
x[j+i+k]=((xx-yy)%p+p)%p;
w=1ll*w*wn%p;
}
}
}
if(f==-1)
{
int invn=mypow(len,p-2);
for(int i=0;i<len;i++)
x[i]=1ll*x[i]*invn%p;
}
}
void inv(int n,int *a,int *b)//n是项数
{
if(n==1)
{
b[0]=mypow(a[0],p-2);
return ;
}
inv((n+1)>>1,a,b);
int len=1,cnt=0;
while(len<=(n<<1)) len<<=1,cnt++;
for(int i=0;i<len;i++)
{
fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
c[i]=(i<n?a[i]:0);
b[i]=(i<n?b[i]:0);
}
ntt(c,len,1);
ntt(b,len,1);
for(int i=0;i<len;i++)
b[i]=(2-(ll)c[i]*b[i]%p+p)%p*b[i]%p;
ntt(b,len,-1);
for(int i=n;i<len;i++)
b[i]=0;
}
void dao(int *a,int *b,int n,int len)
{
for(int i=1;i<n;i++)
b[i-1]=(ll)i*a[i]%mod;
for(int i=n-1;i<len;i++)
b[i]=0;
}
void ji(int *a,int n,int len)
{
for(int i=n-2;i>=0;i--)
a[i+1]=(ll)a[i]*iinv[i+1]%mod;
a[0]=0;
for(int i=n;i<len;i++)
a[i]=0;
}
void ln(int *a,int n)//n是项数
{
int len=getlen(n-2,n-1);
dao(a,da,n,len);
inv(n-1,a,inva);
len=getlen(n-2,n-1);
ntt(da,len,1);
ntt(inva,len,1);
for(int i=0;i<len;i++) lna[i]=(ll)da[i]*inva[i]%mod;
ntt(lna,len,-1);
ji(lna,n,len);
}
void eexp(int *f,int *a,int n)//n是项数
{
if(n==1)
{
a[0]=1;
return ;
}
eexp(f,a,(n+1)>>1);
ln(a,n);
int len=getlen(n-1,n-1);
for(int i=0;i<n;i++)
lna[i]=(-lna[i]+f[i]+mod)%mod;
lna[0]=(lna[0]+1)%mod;
ntt(a,len,1);
ntt(lna,len,1);
for(int i=0;i<len;i++) a[i]=1ll*a[i]*lna[i]%mod;
ntt(a,len,-1);
for(int i=n;i<len;i++)
a[i]=0;
}
int main(void)
{
scanf("%d%*c",&n);
get_inv(n);
char ch;
bool flag=false;
while(isdigit(ch=getchar()))
{
m=m*10+ch-'0';
mm=mm*10+ch-'0';
if(m>=mod) flag=true;
m%=mod;
mm%=(mod-1);
}
for(int i=0;i<n;i++)
scanf("%d",&f[i]);
int k=0;
while(k<n&&f[k]==0) k++;
if(1ll*k*m>=n||(flag&&k))
{
for(int i=0;i<n;i++)
printf("0 ");
putchar('\n');
return 0;
}
n-=k;
for(int i=0;i<n;i++) f[i]=f[i+k];
for(int i=n;i<n+k;i++) f[i]=0;
int f0=f[0],invf=mypow(f[0],mod-2);
for(int i=0;i<n;i++) f[i]=1ll*f[i]*invf%mod;
ln(f,n);
for(int i=0;i<n;i++) f[i]=m*lna[i]%mod,lna[i]=0;
eexp(f,a,n);
int f0mm=mypow(f0,mm);
for(int i=0;i<n;i++) a[i]=1ll*a[i]*f0mm%mod;
int now=k*m;
n+=k;
for(int i=n-1;i>=now;i--) a[i]=a[i-now];
for(int i=now-1;i>=0;i--) a[i]=0;
for(int i=0;i<n;i++) printf("%d ",a[i]);
putchar('\n');
return 0;
}
又来补充一种新的解法:
常数还是巨大,吸氧才过。
考虑为什么要保证 A0=1 --------为了使 ln(A)0 = 0 (对多项式A取对数的常数项)。
为什么要使 ln(A)0=0 --------为了计算exp时,递归到尽头时,使B0=1。(使exp后常数项为 e0=1)
我们考虑直接求出 B=Ak 的常数项作为 b [ 0 ] 来用,那么他是多少呢?------ A0k
考虑 A0≠0 和 A0=0 两种情况来求解即可。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#define ll long long
#define llu unsigned ll
using namespace std;
const int maxn=8e5+100;
const int p=998244353;
const int mod=998244353;
const int g=3;
int fi[maxn];
int a[maxn],inva[maxn],da[maxn],lna[maxn],c[maxn],iinv[maxn];
int f[maxn];
int n;
ll m=0,mm=0;
void get_inv(int n)
{
iinv[1]=1;
for(int i=2;i<=n;i++)
iinv[i]=1ll*(mod-mod/i)*iinv[mod%i]%mod;
}
int mypow(int a,int b)
{
if(b<0) return mypow(mypow(a,p-2),-b);
int ans=1;
while(b)
{
if(b&1) ans=1ll*ans*a%p;
a=1ll*a*a%p;
b>>=1;
}
return ans%p;
}
int getlen(int n,int m)//n,m是次数
{
int len=1,cnt=0;
while(len<=n+m) len<<=1,cnt++;
for(int i=0;i<len;i++)
fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
return len;
}
void ntt(int *x,int len,int f)
{
for(int i=0;i<len;i++)
if(i<fi[i]) swap(x[i],x[fi[i]]);
for(int i=1;i<len;i<<=1)
{
int r=i<<1;
int wn=mypow(g,f*(p-1)/r);
for(int j=0;j<len;j+=r)
{
int w=1;
for(int k=0;k<i;k++)
{
int xx=x[j+k],yy=1ll*w*x[j+i+k]%p;
x[j+k]=(xx+yy)%p;
x[j+i+k]=((xx-yy)%p+p)%p;
w=1ll*w*wn%p;
}
}
}
if(f==-1)
{
int invn=mypow(len,p-2);
for(int i=0;i<len;i++)
x[i]=1ll*x[i]*invn%p;
}
}
void inv(int n,int *a,int *b)//n是项数
{
if(n==1)
{
b[0]=mypow(a[0],p-2);
return ;
}
inv((n+1)>>1,a,b);
int len=1,cnt=0;
while(len<=(n<<1)) len<<=1,cnt++;
for(int i=0;i<len;i++)
{
fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
c[i]=(i<n?a[i]:0);
b[i]=(i<n?b[i]:0);
}
ntt(c,len,1);
ntt(b,len,1);
for(int i=0;i<len;i++)
b[i]=(2-(ll)c[i]*b[i]%p+p)%p*b[i]%p;
ntt(b,len,-1);
for(int i=n;i<len;i++)
b[i]=0;
}
void dao(int *a,int *b,int n,int len)
{
for(int i=1;i<n;i++)
b[i-1]=(ll)i*a[i]%mod;
for(int i=n-1;i<len;i++)
b[i]=0;
}
void ji(int *a,int n,int len)
{
for(int i=n-2;i>=0;i--)
a[i+1]=(ll)a[i]*iinv[i+1]%mod;
a[0]=0;
for(int i=n;i<len;i++)
a[i]=0;
}
void ln(int *a,int n)//n是项数
{
int len=getlen(n-2,n-1);
dao(a,da,n,len);
inv(n-1,a,inva);
len=getlen(n-2,n-1);
ntt(da,len,1);
ntt(inva,len,1);
for(int i=0;i<len;i++) lna[i]=(ll)da[i]*inva[i]%mod;
ntt(lna,len,-1);
ji(lna,n,len);
}
void eexp(int *f,int *a,int n)//n是项数
{
if(n==1)//a[0]已经处理好了
return ;
eexp(f,a,(n+1)>>1);
ln(a,n);
int len=getlen(n-1,n-1);
for(int i=0;i<n;i++)
lna[i]=(-lna[i]+f[i]+mod)%mod;
lna[0]=(lna[0]+1)%mod;
ntt(a,len,1);
ntt(lna,len,1);
for(int i=0;i<len;i++) a[i]=1ll*a[i]*lna[i]%mod;
ntt(a,len,-1);
for(int i=n;i<len;i++)
a[i]=0;
}
int main(void)
{
scanf("%d%*c",&n);
get_inv(n);
char ch;
bool flag=false;
while(isdigit(ch=getchar()))
{
m=m*10+ch-'0';
mm=mm*10+ch-'0';
if(m>=mod) flag=true;
m%=mod;
mm%=(mod-1);
}
for(int i=0;i<n;i++)
scanf("%d",&f[i]);
int k=0;
while(k<n&&f[k]==0) k++;
if(1ll*k*m>=n||(flag&&k))
{
for(int i=0;i<n;i++)
printf("0 ");
putchar('\n');
return 0;
}
n-=k;
for(int i=0;i<n;i++) f[i]=f[i+k];
for(int i=n;i<n+k;i++) f[i]=0;
int f0=mypow(f[0],mm);
ln(f,n);
for(int i=0;i<n;i++) f[i]=m*lna[i]%mod,lna[i]=0;
a[0]=f0;
eexp(f,a,n);
int now=k*m;
n+=k;
for(int i=n-1;i>=now;i--) a[i]=a[i-now];
for(int i=now-1;i>=0;i--) a[i]=0;
for(int i=0;i<n;i++) printf("%d ",a[i]);
putchar('\n');
return 0;
}