FFT初步学习小结

时间:2021-08-25 04:08:05

FFT其实没什么需要特别了解的,了解下原理,(特别推荐算法导论上面的讲解),模板理解就行了。重在运用吧。

处理过程中要特别注意精度。

先上个练习的地址吧:

http://vjudge.net/vjudge/contest/view.action?cid=53596#overview

Problem A A * B Problem Plus

A*B的大数乘法,似乎大数模拟乘法不行的,得用FFT优化到nlogn,将一个数AnAn-1....A1A0,看做An*10^n+An-1*10^n-1+....A1*10+A0*10^0,这样就可以将两个数相乘当成多项式乘法了。

代码:

 #include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <cmath>
#include <queue>
#include <map>
#include <set>
#include <complex>
#define pb push_back
#define in freopen("solve_in.txt", "r", stdin);
#define out freopen("solve_out.txt", "w", stdout);
#define pi (acos(-1.0))
#define bug(x) printf("line %d :>>>>>>>>>>>>>>>\n", x);
#define pb push_back
using namespace std;
#define esp 1e-8 typedef complex<double> CD;
const int maxn = +;
char s1[maxn], s2[maxn]; struct Complex{
double x, y;
Complex(){}
Complex(double x, double y):x(x), y(y){}
};
Complex operator + (const Complex &a, const Complex &b){
Complex c;
c.x = a.x+b.x;
c.y = a.y+b.y;
return c;
}
Complex operator - (const Complex &a, const Complex &b){
Complex c;
c.x = a.x-b.x;
c.y = a.y-b.y;
return c;
}
Complex operator * (const Complex &a, const Complex &b){
Complex c;
c.x = a.x*b.x-a.y*b.y;
c.y = a.x*b.y+a.y*b.x;
return c;
} inline void FFT(vector<Complex> &a, bool inverse)
{
int n = a.size();
for(int i = , j = ; i < n; i++)
{
if(j > i)
swap(a[i], a[j]);
int k = n;
while(j & (k>>=)) j &= ~k;
j |= k;
}
double PI = inverse ? -pi : pi;
for(int step = ; step <= n; step <<= )
{
double alpha = *PI/step;
Complex wn(cos(alpha), sin(alpha));
for(int k = ; k < n; k += step)
{
Complex w(, );
for(int Ek = k; Ek < k+step/; Ek++)
{
int Ok = Ek + step/;
Complex u = a[Ek];
Complex t = a[Ok]*w;
a[Ok] = u-t;
a[Ek] = u+t;
w = w*wn;
}
}
}
if(inverse)
for(int i = ; i < n; i++)
a[i].x = (a[i].x/n);
}
vector<double> operator * (const vector<double> &v1, const vector<double> &v2)
{
int S1 = v1.size(), S2 = v2.size();
int S = ;
while(S < S1+S2) S <<= ;
vector<Complex> a(S), b(S);
for(int i = ; i < S; i++)
a[i].x = a[i].y = b[i].x = b[i].y = 0.0;
for(int i = ; i < S1; i++)
a[i].x = v1[i];
for(int i = ; i < S2; i++)
b[i].x = v2[i];
FFT(a, false);
FFT(b, false);
for(int i = ; i < S; i++)
a[i] = a[i] * b[i];
FFT(a, true);
vector<double> res(S1+S2-, 0.0);
for(int i = ; i < S1+S2-; i++)
res[i] = a[i].x;
return res;
}
int ans[maxn*+];
vector<double > v1, v2;
int main()
{ while(scanf("%s%s", s1, s2) != -)
{ v1.clear();
v2.clear();
int len1 = strlen(s1);
int len2 = strlen(s2);
for(int i = ; s1[i]; i++)
v1.pb((double)(s1[len1--i]-''));
for(int i = ; s2[i]; i++)
v2.pb((double)(s2[len2--i]-''));
v1 = v1*v2;
memset(ans, , sizeof(ans));
int carry = , top = ;
for(int i = ; i < v1.size(); i++){
carry += (int)(v1[i]+0.5);
ans[top++] = carry%;
carry /= ;
}
while(carry){
ans[top++] = carry%;
carry /= ;
}
while(top)
{
if(ans[top])
break;
top--;
}
for(int i = top; i >= ; i--)
printf("%d", ans[i]);
cout<<endl;
}
return ;
}

Problem B 3-idiots

题意:给出一系列边长,问从中选出3条边构成三角形的概率。

分析:首先求出任意两边之和相同的取法有多少种,用每种长度边的数目作系数,FFT即可,然后就是去重了。

对所有边长排序,对于第i条边,sum[i]表示两边之和大于a[i]的边的对数,减去特殊情况,

1.两边均大于a[i],

2,两边一个大于a[i],

3,两边中包含a[i]。

代码:

 #include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <cmath>
#include <queue>
#include <map>
#include <set>
#include <complex>
#define pb push_back
#define in freopen("solve_in.txt", "r", stdin);
#define out freopen("solve_out.txt", "w", stdout);
#define pi (acos(-1.0))
#define bug(x) printf("line %d :>>>>>>>>>>>>>>>\n", x);
#define pb push_back
using namespace std;
#define esp 1e-8 typedef complex<double> CD;
const int maxn = +;
struct Complex
{
double x, y;
Complex() {}
Complex(double x, double y):x(x), y(y) {}
};
Complex operator + (const Complex &a, const Complex &b)
{
Complex c;
c.x = a.x+b.x;
c.y = a.y+b.y;
return c;
}
Complex operator - (const Complex &a, const Complex &b)
{
Complex c;
c.x = a.x-b.x;
c.y = a.y-b.y;
return c;
}
Complex operator * (const Complex &a, const Complex &b)
{
Complex c;
c.x = a.x*b.x-a.y*b.y;
c.y = a.x*b.y+a.y*b.x;
return c;
}
inline void FFT(vector<Complex> &a, bool inverse)
{
int n = a.size();
for(int i = , j = ; i < n; i++)
{
if(j > i)
swap(a[i], a[j]);
int k = n;
while(j & (k>>=)) j &= ~k;
j |= k;
}
double PI = inverse ? -pi : pi;
for(int step = ; step <= n; step <<= )
{
double alpha = *PI/step;
Complex wn = Complex(cos(alpha), sin(alpha));
for(int k = ; k < n; k += step)
{
Complex w(, );
for(int Ek = k; Ek < k+step/; Ek++)
{
int Ok = Ek + step/;
Complex u = a[Ek];
Complex t = w*a[Ok];
a[Ek] = u+t;
a[Ok] = u-t;
w = wn*w;
} }
}
if(inverse)
for(int i = ; i < n; i++)
a[i].x = a[i].x/n;
} vector<double> operator * (const vector<double> &v1, const vector<double> &v2)
{
int S1 = v1.size(), S2 = v2.size();
int S = ;
while(S < S1+S2) S <<= ;
vector<Complex> a(S, Complex(0.0, 0.0)), b(S, Complex(0.0, 0.0));
for(int i = ; i < S1; i++)
a[i].x = v1[i];
for(int i = ; i < S2; i++)
b[i].x = v2[i];
FFT(a, false);
FFT(b, false);
for(int i = ; i < S; i++)
a[i] = a[i] * b[i];
FFT(a, true);
vector<double> res(S1+S2-, 0.0);
for(int i = ; i < S1+S2-; i++)
res[i] = a[i].x;
return res;
}
int a[maxn];
typedef long long LL;
LL sum[maxn<<];
typedef long long LL;
int n;
int main()
{
in
int T;
for(int t = scanf("%d", &T); t <= T; t++)
{
LL ans =;
scanf("%d", &n);
int mx = ;
for(int i = ; i < (maxn<<); i++)
sum[i] = ;
for(int i = ; i < n; i++)
{
scanf("%d", &a[i]);
sum[a[i]]++;
mx = max(a[i], mx);
}
int tmp = ;
while(tmp < *(mx+)) tmp <<= ;//上界 是mx + 1!
vector<Complex> v2(tmp, Complex(0.0, 0.0));
for(int i = ; i <= mx; i++)
v2[i] = (Complex((double)sum[i], 0.0));
FFT(v2, ); for(int i = ; i < v2.size(); i++)
{
v2[i] = v2[i]*v2[i];
}
FFT(v2, );
for(int i = ; i <= mx*; i++)
{
sum[i] = (LL)(v2[i].x+0.5);
}
for(int i = ; i < n; i++)
sum[a[i]*]--;
for(int i = ; i <= *mx; i++)
sum[i] /= ;
for(int i = ; i <= *mx; i++)
sum[i] += sum[i-];
sort(a, a+n);
for(int i = ; i < n; i++)
{
ans += (sum[mx*]-sum[a[i]]);
ans -= ((LL)(n-i-)*(n-i-)/ + (n-) + (LL)i*(n--i));//注意用long long保证精度
}
printf("%.7f\n", 1.0*ans/((LL)n*(n-)*(n-)/));
}
return ;
}

Problem C K-neighbor substrings

给定A,B两个01串,求A中和B串长度相同的且哈密顿距离不超过K的不同子串个数。

分析:

求一个串和另一个串的哈密顿距离,也就是对应位置字符不同的个数,长度-同为1-同为0就是结果了。怎么样找同为1或者同为0的个数呢?由于当且仅当同为1,相乘结果为1,所以将B串反转,那么作FFT,对应系数为1表示相应位置同为1,然后统计系数的和就是两串同为1个数。同为0的个数计算也很简单,只需要将B串翻转一下,0变1,1变0,同样FFT就行了。

由于题目要求不同子串的个数,利用Hash就可以了。

代码:

 #include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <cmath>
#include <queue>
#include <map>
#include <set>
#include <bitset>
#include <complex>
#define pb push_back
#define in freopen("solve_in.txt", "r", stdin);
#define out freopen("solve_out.txt", "w", stdout);
#define pi (acos(-1.0))
#define bug(x) printf("line %d :>>>>>>>>>>>>>>>\n", x);
#define pb push_back
#define esp 1e-8
using namespace std; typedef long long LL;
typedef unsigned long long ULL;
typedef map<ULL, int> MPLL;
const int maxn = +;
const int bb = ;
unsigned long long Hash[maxn], sum[maxn];
MPLL mps;
bitset<maxn> b[];
int n, m, K;
struct Complex
{
double x, y;
Complex() {}
Complex(double x, double y):x(x), y(y) {}
};
Complex operator + (const Complex &a, const Complex &b)
{
Complex c;
c.x = a.x+b.x;
c.y = a.y+b.y;
return c;
}
Complex operator - (const Complex &a, const Complex &b)
{
Complex c;
c.x = a.x-b.x;
c.y = a.y-b.y;
return c;
}
Complex operator * (const Complex &a, const Complex &b)
{
Complex c;
c.x = a.x*b.x-a.y*b.y;
c.y = a.x*b.y+a.y*b.x;
return c;
}
inline void FFT(vector<Complex> &a, bool inverse)
{
int n = a.size();
for(int i = , j = ; i < n; i++)
{
if(j > i)
swap(a[i], a[j]);
int k = n;
while(j & (k>>=)) j &= ~k;
j |= k;
}
double PI = inverse ? -pi : pi;
for(int step = ; step <= n; step <<= )
{
double alpha = *PI/step;
Complex wn = Complex(cos(alpha), sin(alpha));
for(int k = ; k < n; k += step)
{
Complex w(, );
for(int Ek = k; Ek < k+step/; Ek++)
{
int Ok = Ek + step/;
Complex u = a[Ek];
Complex t = w*a[Ok];
a[Ek] = u+t;
a[Ok] = u-t;
w = wn*w;
} }
}
if(inverse)
for(int i = ; i < n; i++)
a[i].x = a[i].x/n+esp;
}
int x[maxn]; void go(int S1, int S2)
{
int S = ;
while(S < S1+S2) S <<= ;
vector<Complex> sa(S, Complex(0.0, 0.0)), sb(S, Complex(0.0, 0.0));
for(int i = ; i < S1; i++)
sa[i].x = b[][i];
for(int i = ; i < S2; i++)
sb[i].x = b[][i];
FFT(sa, false);
FFT(sb, false);
for(int i = ; i < S; i++)
sa[i] = sa[i] * sb[i];
FFT(sa, true);
for(int i = ; i <= n-m; i++)
{
x[i] += round(sa[i+m-].x);
}
}
void solve()
{
int ans = ;
go(n, m);
b[].flip();
b[].flip();
go(n, m);
for(int i = ; i <= n-m; i++)
{
if(m-x[i] <= K)
{
int ok = ;
for(int k = ; k < ; k++)
{
ULL tmp = Hash[i]-Hash[i+m]*sum[m];
if(mps.count(tmp))
{
ok++;
}
}
if(ok == )
{
ans++;
mps[Hash[i]-Hash[i+m]*sum[m]] = ;
// mps[1][Hash[1][0][i]-Hash[1][0][i+m]*sum[1][m]] = 1;
}
}
}
printf("%d\n", ans);
}
char A[maxn], B[maxn];
void pre()
{
sum[] = ;
for(int i = ; i < maxn; i++)
sum[i] = sum[i-]*bb;
} int main()
{ pre();
int kase = ;
while(scanf("%d", &K), K >= )
{
printf("Case %d: ", ++kase);
mps.clear();
Hash[n] = ;
scanf("%s%s", A, B);
n = strlen(A), m = strlen(B);
for(int i = n-; i >= ; i--)
{
x[i] = ;
int t = A[i]-'a';
b[][i] = t;
Hash[i] = Hash[i+]*bb+t;
}
for(int i = m-; i >= ; i--)
{
int t = B[i]-'a';
b[][m--i] = t;
}
solve();
}
return ;
}

Problem D Linear recursive sequence

f(k)  = 0(k <=0)

f(k) = af(k-p)+bf(k-q) 

a,b,k<=10^9, p < q <= 10^4

分析:

具体用到了叉姐的论文《矩阵乘法递推的优化》,其实总结起来就是一个式子,W^i = b1*W^k-1+b2*W^k-2+......bk*E,也就是任何一个W^i都可以表示成E,W^1,W^2,

.....W^k-1的线性组合。求f(n)也就是求出f(n)对应的W^i,(i = n-k+1), 实际就是多次求一个多项式乘法,每次复杂度O(k^2), 这样可以将矩阵乘法优化到k^2log(n)。

但是本题范围较大即使k^2log(n)也不够优化,而且递推系数只有2个,其他都为0,因此可以用FFT加速多项式乘法,O(qlogqlogn)已经很优化了。

代码:

 #pragma comment(linker, "/STACK:16777216")
#include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <cmath>
#include <queue>
#include <map>
#include <set>
#include <bitset>
#include <complex>
#define inf 0x0f0f0f0f
#define pb push_back
#define in freopen("solve_in.txt", "r", stdin);
#define out freopen("solve_out.txt", "w", stdout);
#define pi (acos(-1.0))
#define bug(x) printf("line %d :>>>>>>>>>>>>>>>\n", x);
#define pb push_back
#define esp 1e-8
typedef long long LL;
using namespace std; const int M = ;
int n, p, q, a, b; struct Complex {
double x, y;
Complex() {}
Complex(double x, double y):x(x), y(y) {}
Complex operator + (const Complex &o)const {
return Complex(x+o.x, y+o.y);
}
Complex operator - (const Complex &o)const {
return Complex(x-o.x, y-o.y);
}
Complex operator * (const Complex &o)const {
return Complex(x*o.x-y*o.y, y*o.x+x*o.y);
}
};
void add(double &x, double y) {
long long a = (LL)(x+.), b = (LL)(y+.);
a += b;
a%=M;
x = (double)a;
}
void FFT(vector<Complex> &a, bool inverse) {
int nn = a.size();
for(int i =, j = ; i < nn; i++) {
if(j > i)
swap(a[i], a[j]);
int k = nn;
while(j &(k >>= )) j &=~k;
j |= k;
}
double PI = inverse ? -pi : pi;
for(int step = ; step <= nn; step <<= ) {
Complex wn(cos(PI*/step), sin(PI*/step));
for(int j = ; j < nn; j += step) {
Complex w(, );
for(int Ek = j; Ek < j+step/; Ek++) {
int Ok = Ek + step/;
Complex u = a[Ek];
Complex t = w*a[Ok];
a[Ek] = u+t;
a[Ok] = u-t;
w = w*wn;
}
}
}
if(inverse)
for(int i = ; i < nn; i++)
a[i].x = a[i].x/nn;
}
vector<double> operator *(const vector<double> &v1, const vector<double> &v2) {
int S1 = v1.size();
int S2 = v2.size();
int S = ;
while(S < S1+S2) S <<= ;
vector<Complex> aa(S, Complex(0.0, 0.0)), ab(S, Complex(0.0, 0.0));
for(int i = ; i < S1; i++)
aa[i].x = v1[i];
for(int i = ; i < S2; i++)
ab[i].x = v2[i];
FFT(aa, false);
FFT(ab, false);
for(int i = ; i < S; i++)
aa[i] = aa[i]*ab[i];
FFT(aa, true);
vector<double> res(S1+S2-, 0.0);
for(int i = ; i < S1+S2-; i++) {
res[i] = aa[i].x;
// cout<<aa[i].y<<endl;
// cout<<res[i]<<endl;
}
for(int i = S1+S2-; i >= q; i--) {
add(res[i-p], a*res[i]);
add(res[i-q], b*res[i]);
}
res.resize(q);
return res;
}
const int maxn = (int)3e4+; int h[maxn];
void calPre() {
memset(h, , sizeof h);
h[] = ;
for(int i = ; i < *q-; i++) {
if(i<=p) h[i] = (h[i]+a)%M;
else h[i] = (h[i]+a*h[i-p])%M;
if(i-q <= ) h[i] = (h[i]+b)%M;
else h[i] = (h[i]+b*h[i-q])%M;
}
}
int main() { while(scanf("%d%d%d%d%d", &n, &a, &b, &p, &q) == ) {
a %= M;
b %= M;
calPre();
if(n < q) {
cout<<h[n]<<endl;
continue;
}
n=n-q+;
vector<double> Ma(q, 0.0), base(q, 0.0);
Ma[] = 1.0;
base[] = 1.0;
while(n) {
if(n&) {
Ma = Ma*base;
}
base = base*base;
n>>=;
}
//// Ma = Ma*base;
//// cout<<base[0]<<base[1]<<endl;
//// cout<<Ma[0]<<Ma[1]<<endl;
double res = 0;
for(int i = ; i < q; i++) {
add(res, (Ma[i]*h[q-+i]));
}
cout<<(int)(res+.)<<endl;
}
return ;
}

HDU G++超时,C++2000ms,真是无力了。还有连叉姐标程C++都会超时,G++才能过。不过好像极端数据,标程和我的代码都不能再规定时间跑完,大概15s左右?所以说数据还是很水的。

Problem E Cipher Message 3

题意:给定n个8二进制串,和m个8位二进制串构成2个序列,要求在n个串的序列中找到连续的一段,使得和m个二进制串匹配,匹配的意思是两个序列中对应的二进制串前7位相同,后面一位可以不同,但是会有额外的花费,求最后花费最小的一个匹配序列,而且要求该序列在n中尽量靠左。

分析:

训练赛时遇到的一道题,当时不会FFT,甚至不知道这玩意在竞赛里还能干啥,首先利用KMP找出所有能够匹配的位置,当时知道怎么找花费最小的了。利用FFT求出两个串在各个位置开始的哈密顿距离就可以了。转化和上面C题一样。

代码:

 #include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <queue>
#include <ctime>
#include <map>
#include <set>
#include <cmath>
#include <bitset>
#define pb push_back
#define in freopen("solve_in.txt", "r", stdin);
#define out freopen("solve_out.txt", "w", stdout);
#define pi (acos(-1.0))
#define bug(x) printf("line %d :>>>>>>>>>>>>>>>\n", x);
#define pb push_back
using namespace std;
#define esp 1e-8 const int maxn = +;
int sa[maxn], sb[maxn];
int n, m;
int x[maxn];
bitset<maxn> da, db;
struct Complex
{
double x, y;
Complex() {}
Complex(double x, double y):x(x), y(y) {}
inline Complex operator + (const Complex &b)const
{
return Complex(x+b.x, y+b.y);
}
inline Complex operator - ( const Complex &b)const
{
return Complex(x-b.x, y-b.y);
}
inline Complex operator * (const Complex &b)const
{
return Complex(x*b.x-y*b.y, x*b.y+y*b.x);
}
}; inline void FFT(vector<Complex> &a, bool inverse)
{
int n = a.size();
for(int i = , j = n/; i < n-; i++)
{
if(j > i)
swap(a[i], a[j]);
int k = n>>;
while(j >= k)
{
j-=k;
k >>= ;
}
j += k;
}
double PI = inverse ? -pi : pi;
for(int step = ; step <= n; step <<= )
{
double alpha = *PI/step;
Complex wn(cos(alpha), sin(alpha)); for(int k = ; k < n; k+=step)
{
Complex w(, );
for(int j = k; j < k+step/; j++)
{ Complex u = a[j];
Complex t = w*a[j+step/];
a[j] = u+t;
a[j+step/] = u-t;
w = w*wn;
}
}
}
if(inverse)
for(int i = ; i < n; i++)
a[i].x = a[i].x/n;
}
void go()
{
int S1 = n, S2 = m;
int tmp = max(S1, S2);
int S = ;
while(S < S1+S2) S <<= ;
vector<Complex> a(S, Complex(0.0, 0.0)), b(S, Complex(0.0, 0.0));
for(int i = ; i < S1; i++)
a[i].x = da[i];
for(int i = ; i < S2; i++)
b[i].x = db[i];
FFT(a, false);
FFT(b, false);
for(int i = ; i < S; i++)
{
Complex c;
c.x = (a[i].x*b[i].x)-(a[i].y*b[i].y);
c.y = (a[i].x*b[i].y)+(a[i].y*b[i].x);
a[i] = c;
// a[i] = a[i]*b[i];
}
FFT(a, true);
for(int i = ; i <= n-m; i++)
x[i] += round(esp+a[i+m-].x);
}
int f[maxn];
int match[maxn];
int cnt = ;
void getFail(int a[], int n)
{
f[] = f[] = ;
int j;
for(int i = ; i < n; i++)
{
j = f[i];
while(j && a[i] != a[j]) j = f[j];
f[i+] = (a[i] == a[j] ? j+ : );
}
}
void KMP(int T[], int P[], int n, int m)
{
getFail(P, m);
int j = ;
for(int i = ; i < n; i++)
{
while(j && T[i] != P[j]) j = f[j];
if(T[i] == P[j]) j++;
if(j == m)
{
match[cnt++] = i-m+;
}
}
} char bit[];
void input()
{
for(int i = ; i < n; i++)
{
scanf("%s", bit);
da[i] = bit[]-'';
for(int ii = ; ii < ; ii++)
sa[i] = sa[i]*+bit[ii]-'';
}
for(int i = ; i < m; i++)
{
scanf("%s", bit);
db[m--i] = bit[]-'';
for(int ii = ; ii < ; ii++)
sb[i] = sb[i]*+bit[ii]-'';
}
sa[n] = sb[m] = -;
} void solve()
{
srand(time());
KMP(sa, sb, n, m);
if(cnt == )
{
puts("No");
return;
}
puts("Yes");
go();
da.flip();
db.flip();
go();
int ans = m;
int pos = ;
for(int i = ; i < cnt; i++)
{
if(ans == )
break;
int tmp = match[i];
if(m-x[tmp] < ans)
ans = m-x[tmp], pos = tmp;
} printf("%d %d\n", ans, pos+);
}
int main()
{ scanf("%d%d", &n,&m);
input();
solve();
return ;
}

其实还遇到一道题,里面也用到了FFT,HDU 4954 Permanent,不过先挖个坑。

总结:FFT看上去很繁琐,了解了发现其实也就是一个很好利用的工具。很多东西要尽量主动去学习,如果之前了解过FFT,说不定遇到关键问题就不会像上面卡壳了。对于遇到的新的东西,也要沉下心来,不要只想着切自己学过的,感兴趣的知识,那绝不是进步的方法。