推荐这篇博客,写得真的挺好的。

算了说点自己的理解吧...

首先,FFT 用来计算两个多项式的乘积,能把时间复杂度从 $O(N^2)$ 优化到 $O(N\log N)$。学会它可以解决一些用前缀处理无法解决的优化问题。

FFT 中引入了复数(单位根),这一点非常巧妙:无论是把多项式从系数表示转化为点值表示,还是从点值表示转化回系数表示,都要用到它。

假定我们知道两个多项式(次数分别为 $n$ 和 $m$)的系数表示,那么可以求出它们在至少 $n+m+1$ 个点处的值。为什么需要这么多点呢?其实想一想就明白了:把各次项的系数全部看成未知数,一个 $n$ 次多项式就对应一个 $n+1$ 元方程组,由高斯消元等方法可知,至少需要 $n+1$ 个不同点的信息才能唯一确定它;而乘积多项式的次数是 $n+m$,所以至少需要 $n+m+1$ 个点。

第一步,用多项式 $A$ 的系数求出它在 $2^k$ 个点处的值(取 $2$ 的幂是为了方便分治),同样求出多项式 $B$ 在这 $2^k$ 个点处的值。因为是乘法,把两者的点值逐点相乘,就得到了所求多项式 $C$ 的点值。然后做一次转化:把 $C$ 的点值当作系数,改用单位根的倒数再做一次变换,得到的结果除以项数 $2^k$,就是 $C$ 原本的系数。

然后就可以用分治写出 FFT 了。其中的两个证明都非常简单,建议看看上面推荐的博客。

分治code:

#include <bits/stdc++.h>
using namespace std;

// Capacity bound for the global work arrays (each is declared with N<<2 slots).
const int N = 1000000 + 5;
// pi, obtained as arccos(-1).
const double pi = std::acos(-1.0);

// Minimal complex number: x is the real part, y the imaginary part.
struct cp{
	double x,y;
	// 0 + 0i
	cp():x(0),y(0){}
	// xx + yy*i
	cp(double xx,double yy):x(xx),y(yy){}
};
// Work arrays: f and g receive the two input polynomials (transformed in place),
// ans receives their pointwise product and then the product's coefficients.
cp f[N<<2],g[N<<2],ans[N<<2];

// Component-wise complex addition.
cp operator + (cp A,cp B){
	const double re=A.x+B.x;
	const double im=A.y+B.y;
	return cp(re,im);
}

// Component-wise complex subtraction.
cp operator - (cp A,cp B){
	const double re=A.x-B.x;
	const double im=A.y-B.y;
	return cp(re,im);
}

// Complex product: (a+bi)(c+di) = (ac-bd) + (ad+bc)i.
cp operator * (cp A,cp B){
	const double re=A.x*B.x-A.y*B.y;
	const double im=A.x*B.y+A.y*B.x;
	return cp(re,im);
}

void fft(int n,cp *a,int op)//op等于1系数转点值 op等于-1点值转系数. 
{
	if(n<=1)	return;
	int mid=(n>>1);
	cp a1[mid],a2[mid];
	for(int i=0;i<mid;i++)
	{
		a1[i]=a[i<<1];
		a2[i]=a[i<<1|1];
	}
	fft(mid,a1,op);
	fft(mid,a2,op);
	cp w1(cos(pi/mid),sin(pi/mid)*op),wt,w(1,0);
	for(int i=0;i<mid;i++)
	{
		wt=w*a2[i];
		a[i]=a1[i]+wt;
		a[i+mid]=a1[i]-wt;
		w=w*w1;
	}
}

// Reads n and m, then the n+1 and m+1 integer coefficients of two polynomials
// (indexed from 0), multiplies them with fft and prints the n+m+1 coefficients
// of the product on one line.
void run()
{
	int n,m;
	if(scanf("%d%d",&n,&m)!=2)	return;
	for(int i=0;i<=n;i++)
	{
		int x=0;
		if(scanf("%d",&x)!=1)	return;
		f[i].x=x;
	}
	for(int i=0;i<=m;i++)
	{
		int x=0;
		if(scanf("%d",&x)!=1)	return;
		g[i].x=x;
	}
	// k: smallest power of two greater than n+m, i.e. at least the product length.
	int k=1;
	while(k<=n+m)	k<<=1;
	fft(k,f,1);
	fft(k,g,1);
	// Multiplying point values multiplies the polynomials.
	for(int i=0;i<k;i++)
		ans[i]=f[i]*g[i];
	fft(k,ans,-1);
	// Round each value (divided by k) to the nearest integer exactly once. The former
	// printf("%.0f", v+0.5) rounded twice: an exact 1.0 became 1.5 and printed as "2",
	// and any result with a small positive error was printed one too high.
	for(int i=0;i<=n+m;i++)
		printf("%lld ",llround(ans[i].x/k));
	puts("");
}

// Entry point: runs T test cases; T is fixed at 1 unless the scanf below is re-enabled.
int main()
{
	int T=1;
//	scanf("%d",&T);
	for(;T>0;--T)
		run();
	return 0;
}

还有一种实现同样基于分治,只是把递归改成了迭代。说实话,我个人觉得两者的耗时差不多,迭代版稍快一点,主要是省去了递归中临时数组的空间 qwq。实现也不难。

迭代code:

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacity of the global arrays f, g, ans and rev.
const int N = 4000000 + 5;
// Prime modulus 998244353 (= 119 * 2^23 + 1).
const int mod = 998244353;
// pi, obtained as arccos(-1).
const double pi = std::acos(-1.0);

// Complex number x + y*i.
struct cp{
	double x = 0;	// real part
	double y = 0;	// imaginary part
	cp() = default;
	cp(double X,double Y):x(X),y(Y){}
};
// f, g: input polynomials (transformed in place); ans: their product.
cp f[N],g[N],ans[N];

// (a+bi) + (c+di) = (a+c) + (b+d)i
cp operator + (cp A,cp B)
{
	return {A.x+B.x, A.y+B.y};
}

// (a+bi) - (c+di) = (a-c) + (b-d)i
cp operator - (cp A,cp B)
{
	return {A.x-B.x, A.y-B.y};
}

// (a+bi)(c+di) = (ac-bd) + (ad+bc)i
cp operator * (cp A,cp B)
{
	return {A.x*B.x-A.y*B.y, A.x*B.y+A.y*B.x};
}

// log2 of the transform length k; incremented in main while k is doubled.
int b{0};
// rev[i]: i with its low b bits reversed, the input permutation used by fft.
int rev[N];

void fft(int n,cp *a,int op)
{
	for(int i=0;i<n;i++)
		if(i<rev[i])	swap(a[i],a[rev[i]]);
	for(int len=1;len<=n/2;len<<=1)//枚举分治半长度. 
	{
		cp w1=cp(cos(pi/len),op*sin(pi/len));
		for(int i=0;i<=n-(len<<1);i+=(len<<1))
		{
			cp w=cp(1,0);
			for(int j=0;j<len;j++)
			{
				cp x=a[i+j];cp y=w*a[i+j+len];
				a[i+j]=x+y;
				a[i+j+len]=x-y;
				w=w*w1; 
			}
		}
	}
}

// Reads n and m, then the n+1 and m+1 coefficients of two polynomials,
// multiplies them with the iterative fft and prints the n+m+1 product
// coefficients on one line.
int main()
{
	int n,m;
	if(scanf("%d%d",&n,&m)!=2)	return 1;
	for(int i=0;i<=n;i++)	scanf("%lf",&f[i].x);
	for(int i=0;i<=m;i++)	scanf("%lf",&g[i].x);
	// k = 2^b: the smallest power of two greater than n+m.
	int k=1;
	while(k<=n+m)	k<<=1,b++;
	// Bit-reversal table for indices below k. The former bound i<=k also
	// evaluated (i&1)<<(b-1) with b==0 when n+m==0: a negative shift, which
	// is undefined behavior.
	for(int i=1;i<k;i++)
		rev[i]=(rev[i>>1]>>1)+((i&1)<<(b-1));
	fft(k,f,1);
	fft(k,g,1);
	// Pointwise product of the point values (indices 0..k-1 only).
	for(int i=0;i<k;i++)
		ans[i]=f[i]*g[i];
	fft(k,ans,-1);
	// Round each value (divided by k) to the nearest integer exactly once;
	// printf("%.0f", v+0.5) rounded twice and printed exact results such as
	// 1.0 (-> 1.5) as "2".
	for(int i=0;i<=(n+m);i++)
	{
		printf("%lld ",llround(ans[i].x/k));
	}
	puts("");
	return 0;
}

再补充一个基于 vector 的多项式乘法模板(NTT,模数 998244353)。


using LL = long long;
// FOR(i, x, y): i runs upward over [x, y); FORD(i, x, y): i runs downward over (y, x].
// The bound is evaluated once into _i; i takes the decayed type of the bound expression.
#define FOR(i, x, y) for (std::decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (std::decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; G = 3 is a primitive root modulo it.
const LL MOD = 998244353;
const int G = 3;

// Returns x^n modulo MOD (this parameter shadows the global MOD) by binary
// exponentiation. The result is always in [0, MOD), including for negative x:
// the old version returned a negative residue there, while the NTT code relies
// on canonical residues. Returns 0 when MOD == 1.
// Requires MOD > 0, n >= 0, and (MOD - 1)^2 to fit in a long long.
LL bin(LL x, LL n, LL MOD) {
    LL ret = MOD != 1;
    x %= MOD;
    if (x < 0) x += MOD;
    for (; n; n >>= 1, x = x * x % MOD)
        if (n & 1) ret = ret * x % MOD;
    return ret;
}
 
// Modular inverse x^(p-2) mod p (Fermat's little theorem).
// Requires p prime and x not a multiple of p.
inline LL get_inv(LL x, LL p) { return bin(x, p - 2, p); }
 
LL wn[(N * 10) << 2], rev[(N * 10) << 2];
int NTT_init(int n_) {
    int step = 0; int n = 1;
    for ( ; n < n_; n <<= 1) ++step;
    FOR (i, 1, n)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (step - 1));
    int g = bin(G, (MOD - 1) / n, MOD);
    wn[0] = 1;
    for (int i = 1; i <= n; ++i)
        wn[i] = wn[i - 1] * g % MOD;
    return n;
}
 
void NTT(vector<LL>& a, int n, int f) {
    FOR (i, 0, n) if (i < rev[i])
        std::swap(a[i], a[rev[i]]);
    for (int k = 1; k < n; k <<= 1) {
        for (int i = 0; i < n; i += (k << 1)) {
            int t = n / (k << 1);
            FOR (j, 0, k) {
                LL w = f == 1 ? wn[t * j] : wn[n - t * j];
                LL x = a[i + j];
                LL y = a[i + j + k] * w % MOD;
                a[i + j] = (x + y) % MOD;
                a[i + j + k] = (x - y + MOD) % MOD;
            }
        }
    }
    if (f == -1) {
        LL ninv = get_inv(n, MOD);
        FOR (i, 0, n)
            a[i] = a[i] * ninv % MOD;
    }
}
 
vector<LL> operator+(vector<LL> a, const vector<LL>& b){
    a.resize(max(a.size(), b.size()));
    for(int i = 0; i < b.size(); ++ i)
        a[i] = (a[i] + b[i]) % MOD;
    return a;
}
 
vector<LL> conv(vector<LL> a, vector<LL> b) {
    int len = a.size() + b.size() - 1;
    int n = NTT_init(len);
    a.resize(n);
    b.resize(n);
    NTT(a, n, 1);
    NTT(b, n, 1);
    FOR (i, 0, n)
        a[i] = a[i] * b[i] % MOD;
    NTT(a, n, -1);
    a.resize(len);
    return a;
}