快速傅里叶变换和逆变换的C++实现

时间:2022-05-05 09:53:45

近来做一个大整数乘法的ACM题目时候一直被运行超时所困扰,上网搜索下后发现需要用到快速傅里叶变换和逆变换的算法来实现大整数乘法,才能把复杂度降到LogN * N.

看了一个星期的资料, 吃透了算法才把完整的代码敲出来^^。


代码里有三个主要的函数, 具体约束和说明见代码注释

FFT                              //计算区间内(左闭右开)的复数的离散傅里叶变换(按时间变换DIT-FFT)

IFFT                             //计算区间内(左闭右开)的复数的离散傅里叶逆变换(按时间变换DIT-IFFT)

RaderSort                    //对目标区间(左闭右开)元素进行倒位序排序,雷德算法


使用方法示例 1:

vector<complex<double>> I;
I.push_back(complex<double>(8, 0));
I.push_back(complex<double>(7, 0));
I.push_back(complex<double>(6, 0));
I.push_back(complex<double>(0, 0));
I.push_back(complex<double>(0, 0));
I.push_back(complex<double>(0, 0));
I.push_back(complex<double>(0, 0));
I.push_back(complex<double>(0, 0));

FFT(I.begin(), I.end()); //正向变换
IFFT(I.begin(), I.end()); //逆变换


使用方法示例 2:

<pre name="code" class="cpp">complex<double> K[8];
K[0] = complex<double>(8,0);
K[1] = complex<double>(7, 0);
K[2] = complex<double>(6, 0);
K[3] = complex<double>(0, 0);
K[4] = complex<double>(0, 0);
K[5] = complex<double>(0, 0);
K[6] = complex<double>(0, 0);
K[7] = complex<double>(0, 0);

FFT(K, K + 8);
IFFT(K, K + 8);
 


源码:

#ifndef _YYR_FFT_H_
#define _YYR_FFT_H_

#include <math.h>

using namespace std;

namespace YYR_FFT
{
//定义复数的数据类型,只能是double或者long double,不能编译为64位代码
typedef double ComplexType; //默认使用double
typedef complex<ComplexType> Complex;
//π常数定义
const ComplexType PI = 3.14159265358979323846;
/*----------------如下是提供给内部调用的子函数----------------*/

//检查目标区间的元素的个数是否满足2^M|M >= 2
template<typename T>
inline bool CheckRange(const T Begin, const T End)
{
if(!(Begin < End))
{
return false;
}

size_t N = End - Begin;

if (N < 2)
{
return false;
}

while (N > 1)
{
if (N % 2 > 0)
{
return false;
};
N /= 2;
}

return true;
}
//
template<typename T>
inline void CompareAndSwap(T Left, T Right)
{
if (Left != Right)
{
swap(*Left, *Right);
}
}
//
//计算W值
inline Complex CalculateW(const size_t K, const size_t N)
{
ComplexType Alpha = -2 * PI * K / N;
return Complex(cos(Alpha), sin(Alpha));
};
//计算W值,用于逆变换
inline Complex CalculateRW(const size_t K, const size_t N)
{
ComplexType Alpha = 2 * PI * K / N;
return Complex(cos(Alpha), sin(Alpha));
};
//
template<typename T>
void _FFT(T Begin, T End)
{
typename T::value_type X1;
typename T::value_type X2;

size_t N = End - Begin;

T Left = Begin;
T Right;
for (size_t I = 0; I < N / 2; ++I, Left += 2)
{
Right = Left + 1;

X1 = *Left;
X2 = *Right;
*Left = X1 + X2;
*Right = X1 - X2;
}

for (size_t GroupLength = 4; GroupLength <= N; GroupLength *= 2)
{
size_t GroupTotal = N / GroupLength;
size_t HalfGroupLen = GroupLength / 2;

for (size_t I = 0; I < GroupTotal; ++I)
{
Left = Begin + I * GroupLength;
for (size_t J = 0; J < HalfGroupLen; ++J, ++Left)
{
Right = Left + HalfGroupLen;

X1 = *Left;
X2 = (*Right) * CalculateW(J, GroupLength);

*Left = X1 + X2;
*Right = X1 - X2;
}
}
}
}
//
template<typename T>
void _RaderSort(T Begin, T End)
{
size_t TotalLength = (End - Begin);
size_t N = TotalLength / 2;
size_t I = 1;
size_t J = N;
size_t NN;

--TotalLength;
for (; I < TotalLength; ++I)
{
if (I < J)
{
swap(*(Begin+I), *(Begin+J));
}

NN = N;
while (true)
{
if (NN > J)
{
J += NN;
break;
}
else
{
J -= NN;
NN /= 2;
}
}
}
}

/*---------------如下是提供给外部调用的函数----------------*/

// 计算区间内(左闭右开)的复数的离散傅里叶变换(按时间变换DIT-FFT),使用快速傅里叶变换算法,复杂度如下
// O(N) = N * LogN.
// 约束: 目标区间的复数的个数要等于2^M(M >= 1). 否则变换失败返回false.
// 变换成功后返回true,并且在目标区间保存结果
template<typename T>
bool FFT(T Begin, T End)
{
if (CheckRange(Begin, End) == false)
{
return false;
}

RaderSort(Begin, End);

_FFT(Begin, End);

return true;
}
//
// 计算区间内(左闭右开)的复数的离散傅里叶逆变换(IFFT),使用快速傅里叶变换算法,复杂度如下
// O(N) = N * LogN.
// 约束: 目标区间的复数的个数要等于2^M(M >= 1). 否则变换失败返回false.
// 变换成功后返回true,并且在目标区间保存结果
template<typename T>
bool IFFT(T Begin, T End)
{
if (CheckRange(Begin, End) == false)
{
return false;
}

TranToConjugate(Begin, End);

RaderSort(Begin, End);

_FFT(Begin, End);

TranToConjugate(Begin, End);

size_t N = End - Begin;
for (T Iter = Begin; Iter < End; ++Iter)
{
(*Iter) /= N;
}

return true;
}

//雷德排序算法,对目标区间(左闭右开)元素进行倒位序排序,复杂度位O(N) = N;
template<typename T>
bool RaderSort(T Begin, T End)
{
if (CheckRange(Begin, End) == false)
{
return false;
}

_RaderSort(Begin, End);
return true;
}

//把目标区间(左闭右开)内的复数修改为共轭复数
template<typename T>
void TranToConjugate(T Begin, T End)
{
for (T Iter = Begin; Iter < End; ++Iter)
{
Iter->_Val[1] = 0 - (Iter->_Val[1]);
}
}
//全遍历算法的傅里叶变换,复杂度为
//O(N) = N * N
template<typename T, typename T2>
void F(T Begin, T End, T2& V)
{
size_t Total = End - Begin;
Complex Temp;
for (size_t K = 0; K < Total; ++K)
{
Temp = Complex(0, 0);
for (size_t N = 0; N < Total; ++N)
{
Temp += ((*(Begin + N)) * (CalculateW(K*N, Total)));
}
V.push_back(Temp);
}
}
//全遍历算法的傅里叶逆变换,复杂度为
//O(N) = N * N
template<typename T, typename T2>
void IF(T Begin, T End, T2& V)
{
size_t Total = End - Begin;
Complex Temp;
for (size_t K = 0; K < Total; ++K)
{
Temp = Complex(0, 0);
for (size_t N = 0; N < Total; ++N)
{
Temp += ((*(Begin + N)) * (CalculateRW(K*N, Total)));
}
Temp /= Total;
V.push_back(Temp);
}
}
}

#endif // !_YYR_FFT_H_