「算法笔记」快速傅里叶变换(FFT)

时间:2023-03-09 02:47:09
「算法笔记」快速傅里叶变换(FFT)

一、引入

首先,定义多项式的形式为 \(f(x)=\sum_{i=0}^n a_ix^i\),其中 \(a_i\) 为系数,\(n\) 为次数,这种表示方法称为“系数表示法”,一个多项式是由其系数确定的。

可以证明,\(n+1\) 个点可以唯一确定一个 \(n\) 次多项式。对于 \(f(x)\),代入 \(n+1\) 个不同的 \(x\),得到 \(n+1\) 个不同的 \(y\)。一个 \(n\) 次的多项式就可以等价地换成 \(n+1\) 个等式,相当于平面上的 \(n+1\) 组坐标 \((x_i,y_i)\),这种表示方法称为“点值表示法”

多项式乘法 (卷积):设 \(C(x)=A(x)\cdot B(x)\),\(A,B,C\) 的系数构成的数列分别为 \(a,b,c\),则 \(c_k=\sum_{i=0}^ka_ib_{k-i}\)。理解:因为 \(x^i\times x^{k-i}=x^k\),\(a_i\) 和 \(b_{k-i}\) 相乘后,它们后面的未知数就变成了 \(x^k\),对 \(c_k\) 产生贡献。

暴力求解:两个 \(n\) 次多项式相乘,时间复杂度 \(\mathcal{O}(n^2)\)。

快速傅里叶变换(Fast Fourier Transform,简称 FFT)可以 \(\mathcal{O}(n\log n)\) 求解。

二、基本步骤

利用点值表示法,分三步快速求出多项式乘积:

  1. 由系数表示法转换成点值表示法。(DFT)
  2. 利用点值表示法,求两个多项式的乘积。
  3. 再将点值表示法转化成系数表示法。(IDFT)

对于步骤二,给出两个 \(n\) 次多项式 \(A\) 和 \(B\) 的点值表达式,我们可以 \(\mathcal{O}(n)\) 求出其乘积 \(C\) 的点值表达式。显然 \(C\) 是 \(2n\) 次的,取 \(2n+1\) 个不同的 \(x_i\)。

  • 代入 \(A\):\(\{(x_0,y_0),(x_1,y_1),\cdots,(x_{2n},y_{2n})\}\)。
  • 代入 \(B\):\(\{(x_0,y'_0),(x_1,y'_1),\cdots,(x_{2n},y'_{2n})\}\)。
  • 则 \(C\) 的点值表达式为:\(\{(x_0,y_0y'_0),(x_1,y_1y'_1),\cdots,(x_{2n},y_{2n}y'_{2n})\}\)。

下面重点分析步骤一。有一些前置概念。

三、复数

我们把形如 \(z=a+bi\)(\(a,b\) 为实数)的数称为 复数,其中,\(a,b\) 分别叫做复数 \(z\) 的实部与虚部,\(i\) 为虚数单位,\(i^2=-1\)。

我们把复数表示在 复平面 上(\(x\) 轴叫实轴,\(y\) 轴叫虚轴),就像把实数表示在数轴上。如图,复数 \(z=a+bi\) 与复平面内的点 \(Z(a,b)\) 一一对应。

连接 \(OZ\)。复数 \(z=a+bi\) 与平面向量 \(\vec{OZ}\) 一一对应。

「算法笔记」快速傅里叶变换(FFT)

复数的四则运算:

  • 加减法:\((a+bi)\pm (c+di)=(a\pm c)+(b\pm d)i\)。
  • 乘法:\((a+bi)(c+di)=ac+adi+bci+bdi^2=(ac-bd)+(bc+ad)i\)。
  • 除法:\(\frac{a+bi}{c+di}=\frac{(a+bi)\times (c-di)}{(c+di)\times(c-di)}=\frac{(ac+bd)+(bc-ad)i}{c^2+d^2}=\frac{ac+bd}{c^2+d^2}+\frac{bc-ad}{c^2+d^2}i\)。
complex<double>a;    //STL。a.real() 返回复数 a 的实部,a.imag() 返回复数 a 的虚部
struct cp{ //手写,比 STL 快一点
double a,b;
cp operator+(cp &x){return (cp){a+x.a,b+x.b};}
cp operator-(cp &x){return (cp){a-x.a,b-x.b};}
cp operator*(cp &x){return (cp){a*x.a-b*x.b,a*x.b+b*x.a};}
cp operator/(cp &x){double v=x.a*x.a+x.b*x.b; return (cp){(a*x.a+b*x.b)/v,(b*x.a-a*x.b)/v};} //除法一般不用 qwq
};

我们称 \(\overline{z}=a-bi\) 为复数 \(z=a+bi\) 的 共轭复数

对于复数 \(z=a+bi\),模长 \(|z|\) 为点 \(Z(a,b)\) 到原点的距离,幅角 为对应的向量 \(\vec{OZ}\) 与横轴正半轴的夹角。

复数的乘法可以表达为:模长相乘,辐角相加

四、单位根

1. 定义

\(n\) 次单位根是满足 \(x^n=1\) 的复数 \(x\)。

首先,单位根的模长必然为 \(1\)。因为若 \(|x|>1\),则 \(|x^n|=|x|^n>1\);若 \(|x|<1\),则 \(|x^n|=|x|^n<1\)。

所以单位根表示的点一定在 单位圆(圆心为原点,半径为 \(1\))上。

其次,单位根的辐角 \(\theta\) 一定满足 \(\frac{n\theta}{2\pi}\in\mathbb{Z}\)。也就是一个向量从 \((1,0)\) 开始,每次旋转 \(\theta\) 的角度,旋转 \(n\) 次后还落在 \((1,0)\) 上,那么它一定旋转了整数圈。

然后发现 \(n\) 次单位根正好是模长为 \(1\),辐角为 \(\frac{2k\pi}{n}\) 的向量对应的复数。

记模长为 \(1\),辐角为 \(\frac{2k\pi}{n}\) 的向量对应的 \(n\) 次单位根为 \(\omega_n^k\),称为第 \(k\) 个 \(n\) 次单位根。

还能发现,\(\omega_n^k=\omega_n^{k\bmod n}\),所以一般情况下,我们认为 \(n\) 次单位根有 \(n\) 个,即 \(\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1}\)。

2. 性质

单位根的性质:(这些性质在后文会被用到)

  • 性质 1:\(n\) 次单位根对应的向量将单位圆 \(n\) 等分。

    两个相邻的 \(n\) 次单位根对应的向量的夹角相等。单位根的辐角是周角的 \(\frac{1}{n}\)。

  • 性质 2:\(\omega_n^k=\omega_n^{k\bmod n}\)。

    在弧度制下,任意弧度 \(\theta\) 与 \(\theta+2k\pi\,(k\in\mathbb{Z})\) 表示相同的角。

  • 性质 3:\({(\omega_n^k)}^p=\omega_n^{kp}\)。

    第 \(k\) 个 \(n\) 次单位根对应向量的辐角变为原来的 \(p\) 倍,相当于 \(\omega_n^{kp}\) 对应的向量。

  • 性质 4:\(\omega_{dn}^{dk}=\omega_{n}^k\)。

    考虑两者对应向量的辐角,\(\frac{2dk\pi}{dn}=\frac{2k\pi}{n}\)。也可以这样理解:\(dn\) 次单位根对应的向量将单位圆 \(dn\) 等分,取第 \(dk\) 个。\(n\) 次单位根对应的向量将单位圆 \(n\) 等分,取第 \(k\) 个。两者等价。

  • 性质 5:\(\omega_n^{k+n/2}=-\omega_n^k\)。其中 \(n\) 为偶数。

    相当于一个复数对应的向量进行一次中心对称,\(a+bi\) 变为 \(-a-bi\)。

根据性质 3 有,\({(\omega_n^k)}^2=\omega_{n}^{2k}\)。根据性质 4 有,\(\omega_n^{2k}=\omega_{n/2}^k\),其中 \(n\) 为偶数。

3. 求法

根据性质 3,有 \(\omega_n^k=(\omega_n^1)^k\)。也就是说,只要求出 \(\omega_n^1\),就能得到 \(\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1}\)。

\(\omega_n^1\) 所对应的向量模长为 \(1\),辐角为 \(\frac{2\pi}{n}\),得到 \(\omega_n^1\) 所对应的点为 \((\cos(\frac{2\pi}{n}),\sin(\frac{2\pi}{n}))\)。

求 \(\pi\):double pi=acos(-1)

补充:\(\omega_n^k=e^{\frac{2\pi ik}{n}}=\cos(\frac{2\pi k}{n})+i\cdot \sin(\frac{2\pi k}{n})\)。其中 \(i\) 为虚数单位。

五、DFT

将系数表示法转换成点值表示法。

1. 基本思路

对于 \(n-1\) 次多项式(也就是有 \(n\) 项) \(f(x)=\sum_{i=0}^{n-1} a_ix^i\),将奇偶次数分离。

\(f(x)=(a_0+a_2x^2+a_4x^4+\cdots+a_{n-2}x^{n-2})+(a_1x+a_3x^3+a_5x^5+\cdots+a_{n-1}x^{n-1})\)

定义两个新的多项式 \(f_1(x)\) 和 \(f_2(x)\):

  • \(f_1(x)=a_0+a_2x+a_4x^2+\cdots+a_{n-2}x^{{n/2-1}}\)。
  • \(f_2(x)=a_1+a_3x+a_5x^2+\cdots+a_{n-1}x^{n/2-1}\)。

于是有 \(f(x)=f_1(x^2)+xf_2(x^2)\)。

将 \(\omega_n^k\,(k<\frac{n}{2})\) 代入得:\(f(\omega_n^k)=f_1(\omega_n^{2k})+\omega_n^kf_2(\omega_n^{2k})=f_1(\omega_{n/2}^k)+\omega_n^kf_2(\omega_{n/2}^k)\)。

同理,将 \(\omega_n^{k+n/2}\,(k<\frac{n}{2})\) 代入得:\(f(\omega_n^{k+n/2})=f_1(\omega_n^{2k+n})+\omega_n^{k+n/2}f_2(\omega_n^{2k+n})\)

\(=f_1(\omega_n^{2k})+\omega_n^{k+n/2}f_2(\omega_n^{2k})=f_1(\omega_{n/2}^k)+\omega_n^{k+n/2}f_2(\omega_{n/2}^k)=f_1(\omega_{n/2}^k)-\omega_n^kf_2(\omega_{n/2}^k)\)。

发现两者的右边只有正负号的区别。

第一个式子的 \(k\) 取遍 \([0,\frac{n}{2}-1]\) 时,\(k+\frac{n}{2}\) 取遍 \([\frac{n}{2},n-1]\)。

如果我们知道 \(f_1(x),f_2(x)\) 分别在 \(x=\omega_{n/2}^0,\omega_{n/2}^1,\cdots,\omega_{n/2}^{n/2-1}\) 的点值表示,就可以 \(\mathcal{O}(n)\) 求出 \(f(x)\) 在 \(x=\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1}\) 的点值表示。

然后,发现 \(f_1(x),f_2(x)\) 和 \(f(x)\) 的性质完全相同,这样就把问题分成了两个子问题,对于这两个子问题再进行递归求解。这样就可以在 \(\mathcal{O}(n\log n)\) 的时间复杂度内求出点值表达式。

2. 代码实现

DFT 是利用单位根的特殊性质进行分治。

考虑到它能处理的多项式长度只能为 \(2^k\)(\(k\) 为整数),否则在分治时左右的项数就会不同。我们可以在最高次补一些系数为 \(0\) 的项,把原来的 \(n\) 补到 \(2^k\,(2^k\geq n)\)。这样不影响计算结果。

别忘了:\(f(\omega_n^k)=f_1(\omega_{n/2}^k)+\omega_n^kf_2(\omega_{n/2}^k),f(\omega_n^{k+n/2})=f_1(\omega_{n/2}^k)-\omega_n^kf_2(\omega_{n/2}^k)\)。

递归实现 DFT:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e6+5;
int n,len;
double x,pi=acos(-1);
complex<double>a[N];
void DFT(complex<double>*a,int n){
if(n==1) return ; //边界
int m=n/2;
complex<double>a1[m],a2[m];
for(int i=0;i<m;i++) a1[i]=a[i<<1],a2[i]=a[i<<1|1]; //按奇偶分类
DFT(a1,m),DFT(a2,m); //处理子问题
complex<double>x(cos(2*pi/n),sin(2*pi/n)),w(1,0); //x: w_n^1。w: 当前的 n 次单位根,初始值为 w_n^0,即 1。
for(int i=0;i<m;i++)
a[i]=a1[i]+w*a2[i],a[i+m]=a1[i]-w*a2[i],w*=x; //w*=x: 得到下一个单位根。
}
signed main(){
scanf("%lld",&n);
for(int i=0;i<n;i++)
scanf("%lf",&x),a[i]=x; //a[i] 的实部赋值为 x
for(len=1;len<n;len<<=1); //把项数补到 2 的整幂,高次项的系数默认为 0
DFT(a,len);
for(int i=0;i<len;i++)
printf("(%.4lf,%.4lf)\n",a[i].real(),a[i].imag());
return 0;
}

六、IDFT

DFT 的逆运算。将点值表示法转化成系数表示法。

结论:把 DFT 中的 \(\omega_n^1\) 换成它的共轭复数,即 \((\cos(\frac{2\pi}{n}),-\sin(\frac{2\pi}{n}))\),得到的系数再除以 \(n\) 即可。证明略。

可以把 IDFT 和 DFT 放在一起写。

void FFT(complex<double>*a,int n,int opt){    //opt=1 为 DFT,opt=-1 为 IDFT
if(n==1) return ;
int m=n/2;
complex<double>a1[m],a2[m];
for(int i=0;i<m;i++) a1[i]=a[i<<1],a2[i]=a[i<<1|1];
FFT(a1,m,opt),FFT(a2,m,opt);
complex<double>x(cos(2*pi/n),sin(2*pi/n)*opt),w(1,0);
for(int i=0;i<m;i++)
a[i]=a1[i]+w*a2[i],a[i+m]=a1[i]-w*a2[i],w*=x;
}

七、迭代实现

目前的代码如下。可以发现它的效率并不是很高。

//Luogu P3803
#include<bits/stdc++.h>
using namespace std;
const int N=3e6+5;
int n,m,len;
double x,pi=acos(-1);
complex<double>a[N],b[N];
void FFT(complex<double>*a,int n,int opt){ //opt=1/-1: DFT/IDFT
if(n==1) return ;
int m=n/2;
complex<double>a1[m],a2[m];
for(int i=0;i<m;i++) a1[i]=a[i<<1],a2[i]=a[i<<1|1];
FFT(a1,m,opt),FFT(a2,m,opt);
complex<double>x(cos(2*pi/n),sin(2*pi/n)*opt),w(1,0);
for(int i=0;i<m;i++)
a[i]=a1[i]+w*a2[i],a[i+m]=a1[i]-w*a2[i],w*=x; //蝴蝶操作(只是一个名字 qwq)。这里 w*a2[i] 算了两次,先记录下来再算可以减小常数。
}
signed main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)
scanf("%lf",&x),a[i]=x;
for(int i=0;i<=m;i++)
scanf("%lf",&x),b[i]=x;
n=n+m+1; //一个 n 次多项式和一个 m 次多项式的乘积是一个 n+m 次多项式 (有 n+m+1 项)
for(len=1;len<n;len<<=1);
FFT(a,len,1),FFT(b,len,1);
for(int i=0;i<len;i++) a[i]*=b[i]; //点值直接乘
FFT(a,len,-1);
for(int i=0;i<n;i++)
printf("%d%c",(int)(a[i].real()/len+0.5),i==n-1?'\n':' '); //注意这里是除以 len
return 0;
}

如图,考虑递归的结构:

「算法笔记」快速傅里叶变换(FFT)

求出第 \(4\) 层的状态,就能往上合并求出前 \(3\) 层的状态。

观察发现,最后的数组下标的序列是原序列的 二进制翻转。比如 \(6=(110)_2\),反过来就是 \((011)_2=3\)。而第 \(4\) 层 \(a_3\) 的位置就是原来 \(a_6\) 的位置。

//r[i] 表示 i 二进制翻转后的结果。求 0~len-1 在二进制位数为 log2(len)-1 意义下的二进制翻转。
for(int i=0;i<len;i++) //len 是 2 的幂次
r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0);

理解:考虑 \(i\) 与 \(\frac{i}{2}\) 在二进制下的关系。\(i\) 可以看作是 \(\frac{i}{2}\) 在二进制下的每一位左移一位得到。翻转后,\(i\) 是 \(\frac{i}{2}\) 在二进制下的每一位右移一位得到,然后判一下最后一位即可。

迭代实现:

//Luogu P3803
#include<bits/stdc++.h>
using namespace std;
const int N=3e6+5;
int n,m,len,r[N];
double x,pi=acos(-1);
complex<double>a[N],b[N];
void FFT(complex<double>*a,int n,int opt){ //opt=1/-1: DFT/IDFT
for(int i=0;i<n;i++)
if(i<r[i]) swap(a[i],a[r[i]]); //求出最后一层的序列
for(int k=2;k<=n;k<<=1){ //枚举区间长度
int m=k>>1; //待合并的长度
complex<double>x(cos(2*pi/k),sin(2*pi/k)*opt),w(1,0),v;
for(int i=0;i<n;i+=k,w=1) //枚举起始点
for(int j=i;j<i+m;j++) //遍历区间
v=w*a[j+m],a[j+m]=a[j]-v,a[j]=a[j]+v,w*=x; //蝴蝶操作。注意先 a[j+m]=a[j]-v 再 a[j]=a[j]+v。
}
}
signed main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)
scanf("%lf",&x),a[i]=x;
for(int i=0;i<=m;i++)
scanf("%lf",&x),b[i]=x;
n=n+m+1;
for(len=1;len<n;len<<=1);
for(int i=0;i<len;i++) //二进制翻转
r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0);
FFT(a,len,1),FFT(b,len,1);
for(int i=0;i<len;i++) a[i]*=b[i];
FFT(a,len,-1);
for(int i=0;i<n;i++)
printf("%d%c",(int)(a[i].real()/len+0.5),i==n-1?'\n':' ');
return 0;
}

八、模板

P3803 【模板】多项式乘法(FFT) 为例。

递归实现:

#include<bits/stdc++.h>
using namespace std;
const int N=3e6+5;
int n,m,len;
double x,pi=acos(-1);
complex<double>a[N],b[N];
void FFT(complex<double>*a,int n,int opt){ //opt=1/-1: DFT/IDFT
if(n==1) return ;
int m=n/2;
complex<double>a1[m],a2[m];
for(int i=0;i<m;i++) a1[i]=a[i<<1],a2[i]=a[i<<1|1];
FFT(a1,m,opt),FFT(a2,m,opt);
complex<double>x(cos(2*pi/n),sin(2*pi/n)*opt),w(1,0);
for(int i=0;i<m;i++)
a[i]=a1[i]+w*a2[i],a[i+m]=a1[i]-w*a2[i],w*=x;
}
signed main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)
scanf("%lf",&x),a[i]=x;
for(int i=0;i<=m;i++)
scanf("%lf",&x),b[i]=x;
n=n+m+1;
for(len=1;len<n;len<<=1);
FFT(a,len,1),FFT(b,len,1);
for(int i=0;i<len;i++) a[i]*=b[i];
FFT(a,len,-1);
for(int i=0;i<n;i++)
printf("%d%c",(int)(a[i].real()/len+0.5),i==n-1?'\n':' ');
return 0;
}

迭代实现:(比递归快)

#include<bits/stdc++.h>
using namespace std;
const int N=3e6+5;
int n,m,len,r[N];
double x,pi=acos(-1);
complex<double>a[N],b[N];
void FFT(complex<double>*a,int n,int opt){ //opt=1/-1: DFT/IDFT
for(int i=0;i<n;i++)
if(i<r[i]) swap(a[i],a[r[i]]);
for(int k=2;k<=n;k<<=1){
int m=k>>1;
complex<double>x(cos(2*pi/k),sin(2*pi/k)*opt),w(1,0),v;
for(int i=0;i<n;i+=k,w=1)
for(int j=i;j<i+m;j++) v=w*a[j+m],a[j+m]=a[j]-v,a[j]=a[j]+v,w*=x;
}
}
signed main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)
scanf("%lf",&x),a[i]=x;
for(int i=0;i<=m;i++)
scanf("%lf",&x),b[i]=x;
n=n+m+1;
for(len=1;len<n;len<<=1);
for(int i=0;i<len;i++)
r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0);
FFT(a,len,1),FFT(b,len,1);
for(int i=0;i<len;i++) a[i]*=b[i];
FFT(a,len,-1);
for(int i=0;i<n;i++)
printf("%d%c",(int)(a[i].real()/len+0.5),i==n-1?'\n':' ');
return 0;
}

记忆:

  • \(f(x)=f_1(x^2)+xf_2(x^2)\)。
  • \(f(\omega_n^k)=f_1(\omega_{n/2}^k)+\omega_n^kf_2(\omega_{n/2}^k)\)。
  • \(f(\omega_n^{k+n/2})=f_1(\omega_{n/2}^k)-\omega_n^kf_2(\omega_{n/2}^k)\)。