【51nod 1028】 大数乘法 V2 【FFT/NTT】
FFT
#include<cstdio>
#include<cstring>
#include<cmath>
#include<complex>
// Shorthand macros used by the rest of this solution.
#define in_ inline
#define re_ return
#define op_ operator
#define tt_ template<typename x>
#define st_ static
// inc(l, i, r): loop i over [l, r). Every argument is parenthesized so that
// bounds containing low-precedence operators (e.g. `a & b`) expand correctly.
#define inc(l, i, r) for((i)=(l); (i)<(r); ++(i))
typedef long long ll;
typedef double db;
// Small floating-point tolerance.
const db eps=0.000001;
// Capacity of the global arrays; the FFT length (and the carry slot c[n])
// must stay below it.
const ll mxn=300000;
// Decimal digits of the product, least significant first.
ll c[mxn];
struct com
{
db r, i;
com(){r=i=0;}
com(db a, db b){r=a, i=b;}
tt_ in_ com op_*(x a)
{re_ com(r*a, i*a);}
in_ com op_*(com a)
{re_ com(r*a.r-i*a.i, r*a.i+i*a.r);}
tt_ in_ com op_+(x a)
{re_ com(r+a, i);}
in_ com op_+(com a)
{re_ com(r+a.r, i+a.i);}
tt_ in_ com& op_*=(x a)
{re_ *this=*this*a;}
tt_ in_ com& op_+=(x a)
{re_ *this=*this+a;}
tt_ in_ com& op_=(x a)
{re_ *this=com(a, 0);}
} a[mxn], b[mxn];
// Tiny console helper.
//   io & arr : skip non-digit characters (a leading '-' is skipped like any
//              other non-digit), store the following run of decimal digits in
//              arr[0], arr[1], ... (one digit per element, most significant
//              first) and terminate the run with -1. If input ends before any
//              digit is found, arr[0] is set to -1 (an empty number), so callers
//              that scan for a negative terminator never run off the array.
//   io | "s" : print s followed by a newline.
//   io | v   : print the integer v followed by a space.
struct io
{
    template<typename x> inline io& operator& (x* a)
    {
        int ch;
        while ((ch = getchar()) < '0' || '9' < ch)
            if (ch == EOF)
            {
                *a = -1; // still terminate the (empty) digit run
                return *this;
            }
        for (; '0' <= ch && ch <= '9'; ch = getchar())
            *a++ = ch - '0';
        *a = -1;
        return *this;
    }
    // const char* so string literals bind; char* arguments still convert.
    inline io& operator| (const char* s)
    {
        for (; *s; ++s) putchar(*s);
        puts("");
        return *this;
    }
    inline io& operator| (long long v)
    {
        printf("%lld ", v); // standard conversion for long long (was MSVC-only %I64d)
        return *this;
    }
} io;
#define sw(a, b) (swer=a, a=b, b=swer)
namespace fft
{
ll n, m; com w[mxn], ar1[mxn], ar2[mxn];
in_ ll init(ll sgn)
{
for(ll i=0; i<n; ++i) w[i]=com(
cos(acos(-1)*2*i*sgn/n),
sin(acos(-1)*2*i*sgn/n));
}
in_ int dft(com* a)
{
st_ ll i, j, k; com *f, *g, *swer;
inc((f=ar1, g=ar2, 0), i, n) f[i]=a[i];
for(k=1; k<n; k<<=1, sw(f, g))
for(i=0; i<m; i+=k) inc(0, j, k)
g[i<<1 |j]=f[i|j]+w[j*m/k]*f[m|i|j],
g[i<<1|k|j]=f[i|j]+w[j*m/k]*f[m|i|j]*-1;
inc(0, i, n) a[i]=f[i];
}
}
// Reads two non-negative decimal integers and prints their product.
// Digits become complex coefficients (least significant first), are multiplied
// via forward FFT, pointwise product and inverse FFT, then rounded and carried.
int main()
{
    using namespace fft;
    ll i, j, k; com swer;
    io & a & b;                               // digit arrays, each terminated by a negative entry
    for (i = 0; a[i].r >= 0; ++i);            // i = digit count of the first factor
    for (j = 0; b[j].r >= 0; ++j);            // j = digit count of the second factor
    a[i] = b[j] = 0;                          // clear the terminators
    inc(0, k, i >> 1) sw(a[k], a[i - k - 1]); // reverse to least-significant-first
    inc(0, k, j >> 1) sw(b[k], b[j - k - 1]);
    for (n = 1; n < i || n < j; n <<= 1);
    // Double the length so the cyclic convolution cannot wrap around.
    m = n, n <<= 1, init(1), dft(a), dft(b);
    inc(0, i, n) a[i] *= b[i];
    init(-1), dft(a);
    // Round to nearest: +0.5 tolerates FFT error of either sign, whereas a tiny
    // epsilon would truncate a value like 5.9999998 down to 5.
    inc(0, i, n) c[i] = (ll)floor(a[i].r / n + 0.5);
    inc(0, i, n) c[i + 1] += c[i] / 10, c[i] %= 10;
    // Skip leading zeros but stop at index 0, so a zero product prints "0"
    // instead of reading below c[0].
    for (i = n; i > 0 && !c[i]; --i);
    for (; ~i; --i) putchar(c[i] + 48);
    return 0;
}
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<sys/time.h>
using namespace std;

// 64-bit integer used for all residues and indices.
typedef long long ll;

// Capacity of every global array below.
const ll mxn = 1 << 24;
// Modulus: 998244353 = 119 * 2^23 + 1 (main uses 3 as the root base).
const ll p = 998244353;

// Global state of the benchmark:
//   n       transform length read by main
//   a1, a2  input vector and a pristine copy used to restore it between runs
//   a3      not referenced by the code in this program
//   b, c    ping-pong scratch buffers for the dft variants
//   Ome     table of root-of-unity powers filled by setOme
ll n;
ll a1[mxn], a2[mxn], a3[mxn];
ll b[mxn], c[mxn];
ll Ome[mxn];

// Wall-clock timestamps taken around each timed section.
struct timeval x, y;
// Modular exponentiation: returns a^b mod p (global p) by binary exponentiation.
// The base is first reduced into [0, p) so that a*a never overflows ll, even
// for large or negative bases. Precondition: b >= 0 (a negative exponent never
// reaches 0 under the right shift).
ll disPow(ll a, ll b)
{
    ll r = 1;
    a %= p;
    if (a < 0) a += p;
    for (; b; b >>= 1, a = a * a % p)
        if (b & 1) r = r * a % p;
    return r;
}
// Fill Ome[t] = Ome1^t mod p for t in [0, n); when Ome1 is a primitive n-th
// root of unity this is the twiddle table read by dft1..dft4.
// Ome1 is expected to lie in [0, p) so the products fit in ll.
// Returns 0; the value carries no meaning (the original fell off the end, which is UB).
char setOme(ll Ome1)
{
    Ome[0] = 1;
    for (ll i = 1; i < n; ++i) Ome[i] = Ome[i - 1] * Ome1 % p;
    return 0;
}
// Benchmark variant 1: DFT of a[0..n) modulo p using the root powers in Ome,
// written back to a with every entry in [0, p) (assuming inputs lie in (-p, p)).
// Passes ping-pong between the global scratch arrays b and c with no
// bit-reversal step. Only the butterfly passes are timed; the elapsed
// microseconds are printed on one line.
// Loop counters are plain locals: static ll counters could alias the ll* stores
// and force reloads, distorting the timing being measured.
// Returns 0; the value carries no meaning (the original fell off the end, which is UB).
inline char dft1(ll* a)
{
    ll *f = b, *g = c;
    for (ll i = 0; i < n; ++i) f[i] = a[i];
    gettimeofday(&x, 0);
    for (ll k = 1; k < n; k <<= 1, std::swap(f, g))   // consider the k-th roots of unity
        for (ll i = 0; i < n >> 1; i += k)            // enumerate the sub-polynomials
            for (ll j = 0; j < k; ++j)                // substitute the j-th power of the k-th root
                g[i << 1 | j] = (f[i | j] + Ome[j * n / k >> 1] * f[n >> 1 | i | j]) % p,
                g[i << 1 | k | j] = (f[i | j] - Ome[j * n / k >> 1] * f[n >> 1 | i | j]) % p;
    gettimeofday(&y, 0);
    printf("%lld\n", (y.tv_sec - x.tv_sec) * 1000000ll + y.tv_usec - x.tv_usec);
    for (ll i = 0; i < n; ++i) a[i] = f[i] < 0 ? f[i] + p : f[i];
    return 0;
}
// Benchmark variant 2: same inputs, outputs and timing as dft1 (presumably the
// same transform), but with the loop order reorganized around the n-th root.
// Loop counters are plain locals so they cannot alias the ll* stores.
// Returns 0; the value carries no meaning (the original fell off the end, which is UB).
inline char dft2(ll *a)
{
    ll *f = b, *g = c;
    for (ll i = 0; i < n; ++i) f[i] = a[i];
    gettimeofday(&x, 0);
    for (ll k = 1; k < n; k <<= 1, std::swap(f, g))   // ~= split into n/k polynomials
        for (ll i = 0; i < n; i += n / k)             // i/(n/k)-th power of the k-th root = i-th power of the n-th root
            for (ll j = 0; j < n / k >> 1; ++j)       // enumerate the j-th sub-polynomial
                g[i >> 1 | j] = (f[i | j] + Ome[i >> 1] * f[i | n / k >> 1 | j]) % p,
                g[n >> 1 | i >> 1 | j] = (f[i | j] - Ome[i >> 1] * f[i | n / k >> 1 | j]) % p;
    gettimeofday(&y, 0);
    printf("%lld\n", (y.tv_sec - x.tv_sec) * 1000000ll + y.tv_usec - x.tv_usec);
    for (ll i = 0; i < n; ++i) a[i] = f[i] < 0 ? f[i] + p : f[i];
    return 0;
}
// Benchmark variant 3: same inputs, outputs and timing as dft1 (presumably the
// same transform), iterating k from n down to 2 and hoisting the row pointers
// and the twiddle Ome1 out of the innermost loop.
// Locals (not statics) so the twiddle and counters cannot alias the ll* stores.
// Returns 0; the value carries no meaning (the original fell off the end, which is UB).
inline char dft3(ll *a)
{
    ll *f = b, *g = c;
    for (ll i = 0; i < n; ++i) f[i] = a[i];
    gettimeofday(&x, 0);
    for (ll k = n; k > 1; k >>= 1, std::swap(f, g))   // split into k polynomials (f)
        for (ll i = 0; i < n; i += k)                 // same as above (f)
        {
            ll *g1 = g + (i >> 1), *g2 = g + (n >> 1 | i >> 1);
            const ll *f1 = f + i, *f2 = f + (i | k >> 1);
            const ll Ome1 = Ome[i >> 1];
            for (ll j = 0; j < k >> 1; ++j)
                g1[j] = (f1[j] + Ome1 * f2[j]) % p,
                g2[j] = (f1[j] - Ome1 * f2[j]) % p;
        }
    gettimeofday(&y, 0);
    printf("%lld\n", (y.tv_sec - x.tv_sec) * 1000000ll + y.tv_usec - x.tv_usec);
    for (ll i = 0; i < n; ++i) a[i] = f[i] < 0 ? f[i] + p : f[i];
    return 0;
}
// Benchmark variant 4: identical to dft3 except that the innermost loop walks
// the four row pointers by increment instead of indexing them with j.
// Locals (not statics) so the twiddle and counters cannot alias the ll* stores.
// Returns 0; the value carries no meaning (the original fell off the end, which is UB).
inline char dft4(ll *a)
{
    ll *f = b, *g = c;
    for (ll i = 0; i < n; ++i) f[i] = a[i];
    gettimeofday(&x, 0);
    for (ll k = n; k > 1; k >>= 1, std::swap(f, g))   // split into k polynomials
        for (ll i = 0; i < n; i += k)                 // same as above
        {
            ll *g1 = g + (i >> 1), *g2 = g + (n >> 1 | i >> 1);
            const ll *f1 = f + i, *f2 = f + (i | k >> 1);
            const ll Ome1 = Ome[i >> 1];
            for (ll j = 0; j < k >> 1; ++j, ++g1, ++g2, ++f1, ++f2)
                *g1 = (*f1 + Ome1 * *f2) % p,
                *g2 = (*f1 - Ome1 * *f2) % p;
        }
    gettimeofday(&y, 0);
    printf("%lld\n", (y.tv_sec - x.tv_sec) * 1000000ll + y.tv_usec - x.tv_usec);
    for (ll i = 0; i < n; ++i) a[i] = f[i] < 0 ? f[i] + p : f[i];
    return 0;
}
// Benchmark driver: reads n followed by n integers, then times the four DFT
// variants on the same input (each variant prints its elapsed microseconds).
// n must be a power of two dividing p - 1 = 119 * 2^23 (so 1 <= n <= 2^23);
// otherwise the root of unity and the butterfly indexing are meaningless.
int main()
{
    if (scanf("%lld", &n) != 1 || n <= 0 || (n & (n - 1)) != 0 || (p - 1) % n != 0)
    {
        fprintf(stderr, "n must be a power of two dividing %lld\n", p - 1);
        return 1;
    }
    for (ll i = 0; i < n; ++i)
    {
        if (scanf("%lld", a1 + i) != 1) break;  // missing values stay 0
        // Reduce into [0, p) so the butterflies' products cannot overflow ll;
        // results mod p are unchanged.
        a1[i] %= p;
        if (a1[i] < 0) a1[i] += p;
    }
    // Only a[0..n) is read or written by the dft variants, so n entries suffice.
    const std::size_t bytes = (std::size_t)n * sizeof(ll);
    memcpy(a2, a1, bytes);
    setOme(disPow(3, (p - 1) / n));
    dft1(a1), memcpy(a1, a2, bytes);
    dft2(a1), memcpy(a1, a2, bytes);
    dft3(a1), memcpy(a1, a2, bytes);
    dft4(a1), memcpy(a1, a2, bytes);
    return 0;
}