FFT多项式乘法加速

时间:2022-04-17 07:21:06

FFT基本操作。。。讲解请自己看大学信号转置系列。。。

15-5-30更新:改成结构体的,跪烂王学长啊啊啊啊机智的static。。。

 #include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<queue>
#include<cstring>
#define PAU putchar(' ')
#define ENT putchar('\n')
#pragma comment(linker,"/STACK:10240000,10240000")
using namespace std;
const double PI=acos(-1.0);
const int maxn=+;
struct FFT{
struct cox{
double r,i;cox(double _r = 0.0,double _i = 0.0){r=_r;i=_i;}
cox operator +(const cox &b){return cox(r+b.r,i+b.i);}
cox operator -(const cox &b){return cox(r-b.r,i-b.i);}
cox operator *(const cox &b){return cox(r*b.r-i*b.i,r*b.i+i*b.r);}
}f[maxn];int len,lenx;
void init(int*s,int L,int Len){
len=L;lenx=Len;
for(int i=;i<L;i++) f[i]=cox(s[L--i],);
for(int i=L;i<Len;i++) f[i]=cox(,);return;
}
void change(){
for(int i=,j=lenx>>;i<lenx-;i++){
if(i<j)swap(f[i],f[j]);
int k=lenx>>;
while(j>=k) j-=k,k>>=;
if(j<k) j+=k;
} return;
}
void cal(int tp){
change();
for(int h=;h<=lenx;h<<=){
double tr=-tp**PI/h;
cox wn(cos(tr),sin(tr));
for(int j=;j<lenx;j+=h){
cox w(,);
for(int k=j;k<j+(h>>);k++){
cox u=f[k],t=w*f[k+(h>>)];
f[k]=u+t;f[k+(h>>)]=u-t;w=w*wn;
}
}
} if(tp==-) for(int i=;i<lenx;i++) f[i].r/=lenx;return;
}
};
void mul(int*s1,int*s2,int L1,int L2,int&L,int*ans){
L=;while(L<L1<<||L<L2<<) L<<=;
static FFT a,b;a.init(s1,L1,L);b.init(s2,L2,L);a.cal();b.cal();
for(int i=;i<L;i++) a.f[i]=a.f[i]*b.f[i];a.cal(-);
for(int i=;i<L;i++) ans[i]=(int){a.f[i].r+0.5};return;
}
int s1[maxn>>],s2[maxn>>],ans[maxn],L1,L2,L;
void init(){
char ch;int tot;
do ch=getchar(); while(!isdigit(ch));tot=;
while(isdigit(ch)){s1[tot++]=ch-'';ch=getchar();}L1=tot;
do ch=getchar(); while(!isdigit(ch));tot=;
while(isdigit(ch)){s2[tot++]=ch-'';ch=getchar();}L2=tot;
mul(s1,s2,L1,L2,L,ans);
return;
}
void work(){
for(int i=;i<L;i++){
ans[i+]+=ans[i]/;ans[i]%=;
} L=L1+L2-;return;
}
void print(){
while(ans[L]<=&&L>) L--;
for(int i=L;i>=;i--) putchar(ans[i]+'');
return;
}
int main(){init();work();print();return ;}

原来写的丑哭了TAT

 #include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<queue>
#include<cstring>
#define PAU putchar(' ')
#define ENT putchar('\n')
using namespace std;
const double PI=acos(-1.0);
struct complex{
double r,i;
complex(double _r = 0.0,double _i = 0.0){r=_r;i=_i;}
complex operator +(const complex &b){return complex(r+b.r,i+b.i);}
complex operator -(const complex &b){return complex(r-b.r,i-b.i);}
complex operator *(const complex &b){return complex(r*b.r-i*b.i,r*b.i+i*b.r);}
};
void change(complex y[],int len){
int i,j,k;
for(i=,j=len/;i<len-;i++){
if(i<j)swap(y[i],y[j]);
k=len/;
while(j>=k) j-=k,k/=;
if(j<k) j+=k;
} return;
}
void fft(complex y[],int len,int on){
change(y,len);
for(int h=;h<=len;h<<=){
complex wn(cos(-on**PI/h),sin(-on**PI/h));
for(int j=;j<len;j+=h){
complex w(,);
for(int k=j;k<j+h/;k++){
complex u=y[k],t=w*y[k+h/];
y[k]=u+t;
y[k+h/]=u-t;
w=w*wn;
}
}
}
if(on==-) for(int i=;i<len;i++) y[i].r/=len;
return;
}
const int MAXN=+;
complex x1[MAXN],x2[MAXN];
char str1[MAXN>>],str2[MAXN>>];
int sum[MAXN];
inline int read(){
int x=,sig=;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')sig=-;ch=getchar();}
while(isdigit(ch))x=*x+ch-'',ch=getchar();
return x*=sig;
}
inline void write(int x){
if(x==){putchar('');return;}if(x<)putchar('-'),x=-x;
int len=,buf[];while(x)buf[len++]=x%,x/=;
for(int i=len-;i>=;i--)putchar(buf[i]+'');return;
}
int len,len1,len2;
void init(){
scanf("%s%s",str1,str2);
len1=strlen(str1),len2=strlen(str2),len=;
return;
}
void work(){
while(len<len1<<||len<len2<<) len<<=;
for(int i=;i<len1;i++) x1[i]=complex(str1[len1--i]-'',);
for(int i=len1;i<len;i++) x1[i]=complex(,);
for(int i=;i<len2;i++) x2[i]=complex(str2[len2--i]-'',);
for(int i=len2;i<len;i++) x2[i]=complex(,);
fft(x1,len,);fft(x2,len,);
for(int i=;i<len;i++) x1[i]=x1[i]*x2[i];
fft(x1,len,-);
for(int i=;i<len;i++) sum[i]=(int)(x1[i].r+0.5);
for(int i=;i<len;i++){
sum[i+]+=sum[i]/;
sum[i]%=;
} len=len1+len2-;return;
}
void print(){
while(sum[len]<=&&len>) len--;
for(int i=len;i>=;i--) putchar(sum[i]+'');
return;
}
int main(){init();work();print();return ;}

搜索

复制