简介
FFT (法法塔)是个什么玩意?他的全名叫快速傅里叶变换(然而貌似和傅里叶并没有太大关系),用来快速求出多项式的点值表示,这个东西一般用来解决多项式相乘的问题。
一般的高精度乘法,我们有个 O(n2) 的普及组做法,然而这个做法远远不能满足现代小学生的需求(传言小学生都会FFT)。于是我们要学习更优复杂度的做法,那就是复杂度为 O(nlogn) 的 FFT。
下面,就让我们走进快速多项式乘法吧!
多项式
多项式有两种表示方式:
- 系数表示
这也是我们最常见的形式,即 f(x)=i=0∑naixi - 点值表示
对于一个多项式,我们把 n 个不同的 x 代入其系数表示,就得到了其点值表示,即 {(x1,f(x1))....,(xn,f(xn))},看起来很 naive。。
若 h(x)=f(x)g(x),则 h(xi)=f(xi)g(xi)(这tm不是废话
然而,点值表示是 FFT 的根本。
对于一个点值表示,都能唯一对应一个多项式,具体原因要涉及线性代数的知识,即是范德蒙矩阵可逆,这里略去(根本原因是我忘了怎么证)。这个性质是非常好的,然而,朴素去求 n 个点的点值复杂度是 O(n2) 的。快速傅里叶变换的思想,就是通过快速求出多项式的点值表示,来实现快速多项式乘法。
单位根
对于 xn=1 的解,我们称之为 n 次单位根,表示为 ωn。显然, ωnn=1。(这个真的很显然!!!!!)
这 n 个解,对应的其实是复平面上的 n 个点。(如果你不知道什么为虚数,可以先去百度一发,概括地讲,虚数也是一种数,可以进行加减乘除)
如图,这 n 个点就是 n 个单位根(横轴为实轴,纵轴为虚轴),让我们定义 ωn0 为 (1,0), ωn1 表示将 ωn0 逆时针旋转的第 1 个点, ωnk 表示将 ωn0 逆时针旋转的第 k 个点。实际上,这个 k 也表示 (ωn1)k。
这源自于复数计算的一个性质:复数相乘的二维意义就是模长相乘,幅角相加。
于是, ωnn=1,也可以看作是 ωn1 绕了一圈,回到了 (1,0)
DFT(离散傅里叶变换)
既然复数也是一种数,于是我们聪明的傅里叶同学就想到,为什么不将复数代入到多项式中呢?(这是傅里叶唯一的贡献)
于是,对于 n 次多项式,我们将 n 个 n 次单位根代入多项式中,便得到了原来多项式的一种点值表示 {(ωn0,f(ωn0)),(ωn1,f(ωn1)),....,(ωnn−1,f(ωnn−1))}
这就称为离散傅里叶变换(看起来并没有什么卵用)
FFT(快速傅里叶变换)
有同学肯定会问,得到这些单位根的点值表示,复杂度也是 O(n2) 的!
emmm确实如此。
但是单位根有一些神奇的性质(不然用它来干嘛
- w2n2k=wnk
- wnk+2n=−wnk
(读者自证不难
这个东西可以从单位根乘法的几何意义来理解。 (1) 比较显然, (2) 可以看成是单位根转了半圈。
有了这个又能干吗?
我们将原来的多项式按奇偶拆成两部分,即:
f(x)=(a0+a2x2+...+an−2xn−2)+(a1x+a3x3+...+an−1xn−1)
然后对这两部分分别考虑,设有以下两个多项式:
f1(x)=a0+a2x+...+an−2x2n−1f2(x)=a1+a3x+...+an−1x2n−1
那么不难看出:
f(x)=f1(x2)+xf2(x2)
然后我们将某个 n 单位根 ωns 和 ωns+2n代入,其中 0≤s≤2n(不要恐惧公式,原理真的很简单) f(ωns)=f1(ωn2s)+ωnsf2(ωn2s)f(ωns+2n)=f1(ωn2s+n)+ωns+2nf2(ωn2s+n)
那么由单位根性质 (1) 和 (2):
f(ωns)=f1(ω2ns)+ωnsf2(ω2ns)f(ωns+2n)=f1(ω2ns)−ωnsf2(ω2ns)
由上式看出,我们只要求出了 f1(x) 和 f2(x) 的点值表示,就能快速求出 f(x) 的点值表示。
不过,这个做法的前提是 n 为 2 的幂,所以我们在用 FFT 前,会先将次数拓展到 2 的幂,所以以下所有东西都是基于 n 是 2 的幂。
IDFT(逆离散傅里叶变换)
知道了点值表示,我们总得将它还原成多项式系数吧。不幸的是,朴素的还原也是 O(n2) 的。不过牛逼的傅里叶,提出了离散傅里叶变换的逆变换,让我们有了希望。
实际上,我们让 g(x)=i=0∑nf(ωni)xi
这是一个新的多项式,如果我们分别将单位根的逆 {ωn0,ωn−1....,ωn−n} 代入 g(x) 中,会有极其神奇的事情发生。
在这里直接给出结论:
g(ωn−i)=n×ai
神奇吧,也就是说,我们可以用 FFT 来将点值表示快速还原成系数表示。
递归版 FFT
有同学可能会说:到这里还没告诉我怎么进行多项式乘法呢!
其实,只要快速求出 f(x) 和 g(x) 的点值表示,就能 O(n) 求出 h(x)=f(x)g(x) 的点值表示,然后进行 IDFT,就可以得到 h(x) 的系数了。
代码如下:
fft(a, 1);
fft(b, 1);
for(i = 0; i < limit; i++) a[i] = a[i] * b[i];
fft(a, -1);
代码中, FFT(a,1) 表示对 a 进行 DFT, FFT(a,−1) 表示对 a 进行 IDFT。
至此,不难写出 FFT 的递归版本:
#include <bits/stdc++.h>
#define N 4000005
using namespace std;
const double pi = acos(-1.0);
struct Complex{
double x, y;
Complex(double xx = 0, double yy = 0){x = xx, y = yy;}
}a[N], b[N];
Complex operator + (Complex a, Complex b){ return Complex(a.x + b.x, a.y + b.y);}
Complex operator - (Complex a, Complex b){ return Complex(a.x - b.x, a.y - b.y);}
Complex operator * (Complex a, Complex b){ return Complex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}
int read(){
int x, f = 1;
char ch;
while(ch = getchar(), ch < '0' || ch > '9') if(ch == '-') f = -1;
x = ch - '0';
while(ch = getchar(), ch >= '0' && ch <= '9') x = x * 10 + ch - 48;
return x * f;
}
void fft(int limit, Complex *a, int type){
if(limit == 1) return;
int i;
Complex a1[limit >> 1], a2[limit >> 1];
for(i = 0; i <= limit; i += 2) a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1];
fft(limit >> 1, a1, type);
fft(limit >> 1, a2, type);
Complex Wn = Complex(cos(2.0 * pi / limit), type * sin(2.0 * pi / limit)), w = Complex(1, 0);
for(i = 0; i < (limit >> 1); i++, w = w * Wn){
a[i] = a1[i] + w * a2[i];
a[i + (limit >> 1)] = a1[i] - w * a2[i];
}
}
int main(){
int i, j, n, m, limit = 1;
n = read(); m = read();
for(i = 0; i <= n; i++) a[i].x = read();
for(i = 0; i <= m; i++) b[i].x = read();
while(limit <= n + m) limit <<= 1;
fft(limit, a, 1);
fft(limit, b, 1);
for(i = 0; i <= limit; i++) a[i] = a[i] * b[i];
fft(limit, a, -1);
for(i = 0; i <= n + m; i++) printf("%d ", int(a[i].x / limit + 0.5));
return 0;
}
然而,我的递归版 FFT 效率感人,且又 WA 又 T,让我悲痛欲绝。于是我要介绍迭代版的 FFT
迭代版 FFT
蝴蝶操作
不要被名字吓到,我根本不知道为什么叫蝴蝶操作。。
我们按奇偶来操作,如下图:
我们惊奇地发现,原序列和后序列的区别,在于原序列是后序列的二进制翻转!
于是我们很容易求出后序列是哪些,然后从下往上更新。
求后序列的代码如下:
for(i = 0; i < limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
只是一个小小的 dp。
三次变两次
原来我们是要进行三次 FFT 的,但是如果我们将 g(x) 放到 f(x) 的虚部,就可以只用两次 FFT,常数降为 32。
证明:
设 p(x)=f(x)+g(x)i,则 p2(x)=f2(x)−g2(x)+2f(x)g(x)i
将虚部除以 2 就是我们要求的东西。
总代码如下
#include <bits/stdc++.h>
#define N 4000005
using namespace std;
struct Complex{
double x, y;
}a[N], b[N];
const double pi = acos(-1.0);
Complex operator + (Complex a, Complex b){ return Complex{a.x + b.x, a.y + b.y};};
Complex operator - (Complex a, Complex b){ return Complex{a.x - b.x, a.y - b.y};};
Complex operator * (Complex a, Complex b){ return Complex{a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};};
int rev[N], limit = 1, len;
int read(){
int x, f = 1;
char ch;
while(ch = getchar(), ch < '0' || ch > '9') if(ch == '-') f = -1;
x = ch - 48;
while(ch = getchar(), ch >= '0' && ch <= '9') x = x * 10 + ch - 48;
return x * f;
}
void fft(Complex *a, int type){
int i, j, k, mid, R;
Complex Wn, w;
for(i = 0; i < limit; i++){
if(i < rev[i]) swap(a[i], a[rev[i]]);//得到后序列,i < rev[i] 是保证只交换一次
}
for(mid = 1; mid < limit; mid <<= 1){//mid 是每次要处理序列长度的一半
Wn = Complex{cos(pi / mid), type * sin(pi / mid)};//得到单位根,角度是 2 * pi / 2 * mid, 2被约掉了
for(R = (mid << 1), j = 0; j < limit; j += R){//枚举序列左端点
w = Complex{1, 0};//得到单位根的 0 次方
for(k = 0; k < mid; k++, w = w * Wn){//j + k是在序列中的位置,同时得到单位根的 k 次方
Complex x = a[j + k], y = a[j + mid + k] * w;//由单位根的性质(1),(2) 推导而来
a[j + k] = x + y;
a[j + mid + k] = x - y;
}
}
}
}
int main(){
int i, j, n, m;
n = read(); m = read();
while(limit <= n + m) limit <<= 1, len++;//找到大于 n + m 的最小的 2 的幂
for(i = 0; i < limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));//预处理后序列
for(i = 0; i <= n; i++) a[i].x = read();//快读,降低代码常数
for(i = 0; i <= m; i++) a[i].y = read();//三次变两次优化,将多项式 b 读入 a 的虚部
fft(a, 1);//DFT
for(i = 0; i < limit; i++) a[i] = a[i] * a[i];//平方
fft(a, -1);//IDFT
for(i = 0; i <= n + m; i++) printf("%d ", int(a[i].y / (2 * limit) + 0.5));//记得四舍五入,否则精度会有问题
return 0;
}
UPD:
代码加了注释。