NTT模板:
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Capacity shared by every global buffer of this program.
const int maxn = 300010;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; 3 is a primitive root modulo it.
const int mod = 998244353;

// Modular exponentiation by square-and-multiply: returns base^e mod `mod`.
// e must be non-negative (the loop relies on e shifting down to 0).
inline int qpow(int base, int e) {
    long long acc = 1;
    long long cur = base;
    for (; e != 0; e >>= 1) {
        if (e & 1) acc = acc * cur % mod;
        cur = cur * cur % mod;
    }
    return (int)acc;
}

// Bit-reversal permutation for the current transform size. ntt() reads
// r[0..lim), so it must be filled for lim before every transform.
int r[maxn];

// In-place iterative number-theoretic transform of x[0..lim) modulo `mod`.
//   x   : coefficients, expected to lie in [0, mod)
//   lim : transform size; must be a power of two no larger than 2^23 so that
//         (mod - 1) / span is exact for every butterfly span
//   opt : 1 for the forward transform, -1 for the inverse
// The inverse reuses the forward butterflies, then reverses x[1..lim) and
// scales by lim^{-1}; this is equivalent to transforming with the inverse root.
inline void ntt(int *x, int lim, int opt) {
    for (int i = 0; i < lim; ++i) {
        if (i < r[i]) swap(x[i], x[r[i]]);
    }
    for (int half = 1; (half << 1) <= lim; half <<= 1) {
        const int span = half << 1;
        // Principal span-th root of unity: 3^((mod-1)/span).
        const int step = qpow(3, (mod - 1) / span);
        for (int base = 0; base < lim; base += span) {
            int *lo = x + base;
            int *hi = lo + half;
            int w = 1;
            for (int j = 0; j < half; ++j) {
                const int t = 1LL * hi[j] * w % mod;
                hi[j] = (lo[j] - t + mod) % mod;
                lo[j] = (lo[j] + t) % mod;
                w = 1LL * w * step % mod;
            }
        }
    }
    if (opt != -1) return;
    reverse(x + 1, x + lim);
    const int invLim = qpow(lim, mod - 2);
    for (int i = 0; i < lim; ++i) x[i] = 1LL * x[i] * invLim % mod;
}
int A[maxn],B[maxn],C[maxn];
char a[maxn],b[maxn];
int main() {
scanf("%s",a);
scanf("%s",b);
int i,lim(1),len1,len2;
len1=strlen(a),len2=strlen(b);
for(i=0; i<len1; ++i) A[i]=a[len1-i-1]-'0';
for(i=0; i<len2; ++i) B[i]=b[len2-i-1]-'0';
while(lim<(len1<<1)) lim<<=1;
while(lim<(len2<<1)) lim<<=1;
for(i=0; i<lim; ++i) r[i]=(i&1)*(lim>>1)+(r[i>>1] >> 1);
ntt(A,lim,1);
ntt(B,lim,1);
for(i=0; i<lim; ++i) C[i]=1LL*A[i]*B[i]%mod;
ntt(C,lim,-1);
int len(0);
for(i=0; i<lim; ++i) {
if(C[i]>=10) len=i+1,C[i+1]+=C[i]/10,C[i]%=10;
if(C[i]&&i>len) len=i;
}
while(C[len]>=10) C[len+1]+=C[len]/10,C[len]%=10,++len;
for(i=len; ~i; --i) putchar(C[i]+'0');
puts("");
return 0;
}
若要支持多组输入，只需把上面代码的主函数替换为如下版本：
int main() {
while(~scanf("%s",a)) {
scanf("%s",b);
int i,lim(1),len1,len2;
len1=strlen(a),len2=strlen(b);
for(i=0;i<len1;++i) A[i]=a[len1-i-1]-'0';
for(i=0;i<len2;++i) B[i]=b[len2-i-1]-'0';
while(lim<(len1<<1)) lim<<=1;
while(lim<(len2<<1)) lim<<=1;
for(i=0;i<lim;++i) r[i]=(i&1)*(lim>>1)+(r[i>>1] >> 1);
ntt(A,lim,1);
ntt(B,lim,1);
for(i=0;i<lim;++i) C[i]=1LL*A[i]*B[i]%mod;
ntt(C,lim,-1);
int len(0);
for(i=0;i<lim;++i) {
if(C[i]>=10) len=i+1,C[i+1]+=C[i]/10,C[i]%=10;
if(C[i]&&i>len) len=i;
}
while(C[len]>=10) C[len+1]+=C[len]/10,C[len]%=10,++len;
for(i=len;~i;--i) putchar(C[i]+'0');
puts("");
for(int i=0;i<lim;++i) A[i]=B[i]=0;
}
return 0;
}
FFT模板:
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const double PI = acos(-1.0);

// Minimal complex number providing the three operations the FFT needs.
struct Complex {
    double x, y;  // real and imaginary parts
    Complex(double re = 0.0, double im = 0.0) : x(re), y(im) {}
    Complex operator-(const Complex &o) const { return Complex(x - o.x, y - o.y); }
    Complex operator+(const Complex &o) const { return Complex(x + o.x, y + o.y); }
    Complex operator*(const Complex &o) const {
        return Complex(x * o.x - y * o.y, x * o.y + y * o.x);
    }
};

// Capacity of the transform buffers (r here, x1/x2/sum in main).
const int maxn = 300007;
// Bit-reversal permutation for the current transform size; fft() reads r[0..len).
int r[maxn];

// In-place iterative radix-2 FFT over y[0..len).
//   len : transform size (a power of two)
//   opt : 1 for the forward transform, -1 for the inverse
// The inverse divides only the real parts by len, because main reads only .x.
void fft(Complex *y, int len, int opt) {
    for (int i = 0; i < len; ++i) {
        if (i < r[i]) swap(y[i], y[r[i]]);
    }
    for (int step = 2; step <= len; step <<= 1) {
        const int half = step / 2;
        const Complex wn(cos(2 * PI / step), sin(opt * 2 * PI / step));
        for (int base = 0; base < len; base += step) {
            Complex w(1, 0);
            for (int k = base; k < base + half; ++k) {
                Complex &lo = y[k];
                Complex &hi = y[k + half];
                const Complex t = w * hi;
                hi = lo - t;
                lo = lo + t;
                w = w * wn;
            }
        }
    }
    if (opt != -1) return;
    for (int i = 0; i < len; ++i) y[i].x /= len;
}
Complex x1[maxn],x2[maxn];
char a[maxn/2],b[maxn/2];
int sum[maxn];
int main() {
scanf("%s%s",a,b);
int len1=strlen(a);
int len2=strlen(b);
int len=1;
for(int i=0; i<len1; ++i) x1[i]=Complex(a[len1-1-i]-'0',0);
for(int i=0; i<len2; ++i) x2[i]=Complex(b[len2-1-i]-'0',0);
while(len<(len1<<1)) len<<=1;
while(len<(len2<<1)) len<<=1;
for(int i=1; i<len; ++i) r[i]=(i&1)*(len>>1)+(r[i>>1]>>1);
fft(x1,len,1);
fft(x2,len,1);
for(int i=0; i<len; ++i) x1[i]=x1[i]*x2[i];
fft(x1,len,-1);
for(int i=0; i<len; ++i) sum[i]=int(x1[i].x+0.5);
int n(0);
for(int i=0; i<len; ++i) {
if(sum[i]>=10) n=i+1,sum[i+1]+=sum[i]/10,sum[i]%=10;
if(sum[i]&&i>n) n=i;
}
while(sum[n]>=10) sum[n+1]+=sum[n]/10,sum[n]%10,++n;
for (int i = n; ~i; --i) putchar(sum[i] + '0');
puts("");
return 0;
}
把主函数改成如下就是多组输入的格式:
int main() {
while(scanf("%s%s",a,b)==2) {
int len1=strlen(a);
int len2=strlen(b);
int len=1;
for(int i=0;i<len1;++i) x1[i]=Complex(a[len1-1-i]-'0',0);
for(int i=0;i<len2;++i) x2[i]=Complex(b[len2-1-i]-'0',0);
while(len<(len1<<1)) len<<=1;
while(len<(len2<<1)) len<<=1;
for(int i=1;i<len;++i) r[i]=(i&1)*(len>>1)+(r[i>>1]>>1);
fft(x1,len,1);
fft(x2,len,1);
for(int i=0;i<len;++i) x1[i]=x1[i]*x2[i];
fft(x1,len,-1);
for(int i=0;i<len;++i) sum[i]=int(x1[i].x+0.5);
int n(0);
for(int i=0;i<len;++i) {
if(sum[i]>=10) n=i+1,sum[i+1]+=sum[i]/10,sum[i]%=10;
if(sum[i]&&i>n) n=i;
x1[i]=x2[i]=Complex(0,0);
}
while(sum[n]>=10) sum[n+1]+=sum[n]/10,sum[n]%10,++n;
for (int i = n; ~i; --i) putchar(sum[i] + '0');
puts("");
}
return 0;
}