【51nod 1028】 大数乘法 V2 【FFT/NTT】

时间:2021-04-22 14:50:17

【51nod 1028】 大数乘法 V2 【FFT/NTT】

FFT

AC代码。

#include<cstdio>
#include<cstring>
#include<cmath>
#include<complex>
// Keyword abbreviations used throughout this program.
#define in_ inline
#define re_ return
#define op_ operator
// tt_ introduces a one-parameter template whose type parameter is named `x`.
#define tt_ template<typename x>
#define st_ static
// inc(l, i, r): loop the already-declared variable i from l while i<r, step 1.
// `l` is pasted verbatim into the init clause, so it may be a comma expression.
#define inc(l, i, r) for(i=l; i<r; ++i)
typedef long long ll;
typedef double db;

// Small floating-point tolerance (1e-6).
const db eps=0.000001;
// Capacity of every fixed-size array declared with [mxn] in this program.
const ll mxn=300000;

// Integer result buffer; main stores one decimal digit of the product per element.
ll c[mxn];
struct com
{
    db r, i;
    com(){r=i=0;}
    com(db a, db b){r=a, i=b;}
    tt_ in_ com op_*(x a)
        {re_ com(r*a, i*a);}
    in_ com op_*(com a)
        {re_ com(r*a.r-i*a.i, r*a.i+i*a.r);}
    tt_ in_ com op_+(x a)
        {re_ com(r+a, i);}
    in_ com op_+(com a)
        {re_ com(r+a.r, i+a.i);}
    tt_ in_ com& op_*=(x a)
        {re_ *this=*this*a;}
    tt_ in_ com& op_+=(x a)
        {re_ *this=*this+a;}
    tt_ in_ com& op_=(x a)
        {re_ *this=com(a, 0);}
} a[mxn], b[mxn];

// Tiny stdin/stdout helper. The single instance declared at the end of the
// struct is also called `io`, so expressions such as `io & buf` use the object.
struct io
{
    // Reads the next run of decimal digits from stdin.
    //
    // Leading non-digit characters are skipped. Each digit is stored as its
    // numeric value (0..9) in consecutive elements starting at a[0], and the
    // element after the last digit is set to -1 as a terminator. The first
    // character after the digit run is consumed.
    //
    // If EOF is reached before any digit, a[0] is still set to -1 so that a
    // caller scanning for the terminator sees an empty number instead of
    // running past the end of an unterminated buffer.
    //
    // The caller must provide room for all digits plus the terminator.
    template<typename x> inline io& operator&(x* a)
    {
        int c=getchar();
        while(c!=EOF && (c<'0' || '9'<c))
            c=getchar();
        for(; '0'<=c && c<='9'; c=getchar())
            *a++=c-'0';
        *a=-1;
        return *this;
    }

    // Writes the NUL-terminated string s followed by a newline.
    // Takes const char* so that string literals are accepted as well.
    inline io& operator|(const char* s)
    {
        puts(s);
        return *this;
    }

    // Writes the integer a followed by a single space.
    inline io& operator|(long long a)
    {
        // %lld is the standard conversion for long long; the previous
        // %I64d is only understood by the Microsoft C runtime.
        printf("%lld ", a);
        return *this;
    }
} io;

// sw(a, b): swaps a and b through a temporary named `swer`. The macro does not
// declare `swer`; a variable of that name and a compatible type must already
// be in scope wherever sw is used. Both arguments are evaluated twice.
#define sw(a, b) (swer=a, a=b, b=swer)

namespace fft
{
    ll n, m; com w[mxn], ar1[mxn], ar2[mxn];
    in_ ll init(ll sgn)
    {
        for(ll i=0; i<n; ++i) w[i]=com(
            cos(acos(-1)*2*i*sgn/n),
            sin(acos(-1)*2*i*sgn/n));
    }
    in_ int dft(com* a)
    {
        st_ ll i, j, k; com *f, *g, *swer;
        inc((f=ar1, g=ar2, 0), i, n) f[i]=a[i];
        for(k=1; k<n; k<<=1, sw(f, g))
            for(i=0; i<m; i+=k) inc(0, j, k)
                g[i<<1  |j]=f[i|j]+w[j*m/k]*f[m|i|j],
                g[i<<1|k|j]=f[i|j]+w[j*m/k]*f[m|i|j]*-1;
        inc(0, i, n) a[i]=f[i];
    }
}

int main()
{
    using namespace fft;
    ll i, j, k; io&a&b; com swer;
    for(i=0; a[i].r>=0; ++i);
    for(j=0; b[j].r>=0; ++j);
    a[i]=b[j]=0;
    inc(0, k, i>>1) sw(a[k], a[i-k-1]);
    inc(0, k, j>>1) sw(b[k], b[j-k-1]);
    for(n=1; n<i||n<j; n<<=1);
    m=n, n<<=1, init(1), dft(a), dft(b);
    inc(0, i, n) a[i]*=b[i];
    init(-1), dft(a);
    inc(0, i, n) c[i]=(ll)floor(a[i].r/n+eps);
    inc(0, i, n) c[i+1]+=c[i]/10, c[i]%=10;
    for(i=n; !c[i]; --i);
    for(;~i; --i) putchar(c[i]+48);
    re_ 0;
}

改进了一下 FFT/NTT 的写法，最后变成 dft3 里面的样子，可见无论是代码的简洁程度还是实际运行时间（常数）都有了质的飞跃（渐进时间复杂度不变，仍为 O(n log n)）。dft4 是另一种可以参考的写法，两者的运行时间不相上下。

#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<sys/time.h>
using namespace std;
typedef long long ll;

// mxn: capacity of every array below. p: the prime modulus used for all
// modular arithmetic in this program (main uses 3 as the base when deriving
// the root passed to setOme).
const ll mxn=1<<24, p=998244353;

// n: transform length read by main. a1: input / transform target, a2: pristine
// copy of the input used to restore a1 between runs. a3 is not referenced in
// the code shown here. b and c are the two work buffers the dft variants
// alternate between; Ome is the table of powers filled by setOme.
ll n, a1[mxn], a2[mxn], a3[mxn], b[mxn], c[mxn], Ome[mxn];

// Timestamps taken with gettimeofday around the measured loops.
struct timeval x, y;

// Modular exponentiation: returns a^b mod p by square-and-multiply.
//
// a may be any value; it is reduced into [0, p) first so that the products
// below stay under p*p and cannot overflow a signed 64-bit integer.
// b is the exponent; an exponent <= 0 yields 1 (a negative exponent
// previously never terminated).
ll disPow(ll a, ll b)
{
    a%=p;
    if(a<0) a+=p;

    ll r=1; // plain local instead of a function-static accumulator
    for(; b>0; b>>=1)
    {
        if(b&1) r=r*a%p;
        a=a*a%p;
    }
    return r;
}

// Fills the global table with successive powers of Ome1 modulo p:
// Ome[i] = Ome1^i mod p for i in [0, n), with Ome[0] = 1 written even if n <= 1.
// Always returns 0. The function previously had no return statement, which is
// undefined behaviour for a non-void function.
char setOme(ll Ome1)
{
    Ome[0]=1;
    for(ll i=1; i<n; ++i)
        Ome[i]=Ome[i-1]*Ome1%p;
    return 0;
}

// Benchmark variant 1 of the length-n transform modulo p.
//
// Copies a[0..n) into work buffer b, runs the passes while alternating
// between b and c, prints the elapsed time of the passes in microseconds on
// its own line, and writes the result back to a normalised into [0, p).
// Uses the power table Ome, which must have been filled by setOme.
// Always returns 0 (the function previously fell off its end, which is
// undefined behaviour for a non-void function).
inline char dft1(ll* a)
{
    // The loop variables stay function-static and the inner loop is left
    // untouched: this code is what the benchmark measures.
    static ll i, j, k, *f, *g;
    f=b, g=c;
    for(i=0; i<n; ++i) f[i]=a[i];
    gettimeofday(&x, 0);
    for(k=1; k<n; k<<=1, swap(f, g))// working with the k-th root of unity
        for(i=0; i<(n>>1); i+=k)// enumerate the sub-polynomials
            for(j=0; j<k; ++j)// substitute the j-th power of the k-th root
                g[i<<1 |   j]=(f[i | j]+Ome[j*n/k>>1]*f[n>>1|i | j])%p,
                g[i<<1 | k|j]=(f[i | j]-Ome[j*n/k>>1]*f[n>>1|i | j])%p;
    gettimeofday(&y, 0);
    // %lld with an explicit cast is portable; %I64d is MSVC-specific.
    printf("%lld\n", (long long)((y.tv_sec-x.tv_sec)*1000000ll+y.tv_usec-x.tv_usec));
    // Intermediate values lie in (-p, p); fold negatives back into [0, p).
    for(i=0; i<n; ++i) a[i]=f[i]<0? f[i]+p: f[i];
    return 0;
}

// Benchmark variant 2 of the length-n transform modulo p: same passes as
// variant 1 but with the two inner loops exchanged, so the outer of them walks
// the powers of the root and the inner one walks the sub-polynomials.
//
// Copies a[0..n) into work buffer b, alternates between b and c, prints the
// elapsed time of the passes in microseconds on its own line, and writes the
// result back to a normalised into [0, p). Uses the power table Ome.
// Always returns 0 (the function previously fell off its end, which is
// undefined behaviour for a non-void function).
inline char dft2(ll *a)
{
    // Static loop variables and the inner loop are kept exactly as measured.
    static ll i, j, k, *f, *g;
    f=b, g=c;
    for(i=0; i<n; ++i) f[i]=a[i];
    gettimeofday(&x, 0);
    for(k=1; k<n; k<<=1, swap(f, g))// roughly: split into n/k polynomials
        for(i=0; i<n; i+=n/k)// (i/(n/k))-th power of the k-th root = i-th power of the n-th root
            for(j=0; j<(n/k>>1); ++j)// enumerate the j-th sub-polynomial
                g[     i>>1 | j]=(f[i | j]+Ome[i>>1]*f[i | n/k>>1|j])%p,
                g[n>>1|i>>1 | j]=(f[i | j]-Ome[i>>1]*f[i | n/k>>1|j])%p;
    gettimeofday(&y, 0);
    // %lld with an explicit cast is portable; %I64d is MSVC-specific.
    printf("%lld\n", (long long)((y.tv_sec-x.tv_sec)*1000000ll+y.tv_usec-x.tv_usec));
    // Intermediate values lie in (-p, p); fold negatives back into [0, p).
    for(i=0; i<n; ++i) a[i]=f[i]<0? f[i]+p: f[i];
    return 0;
}

// Benchmark variant 3 of the length-n transform modulo p: the block size k
// runs downwards from n, and the base pointers and the twiddle of each block
// are hoisted out of the innermost loop, which then indexes them with j.
//
// Copies a[0..n) into work buffer b, alternates between b and c, prints the
// elapsed time of the passes in microseconds on its own line, and writes the
// result back to a normalised into [0, p). Uses the power table Ome.
// Always returns 0 (the function previously fell off its end, which is
// undefined behaviour for a non-void function).
inline char dft3(ll *a)
{
    // Static loop variables and the inner loop are kept exactly as measured.
    static ll i, j, k, *f, *g, *g1, *g2, *f1, *f2, Ome1;
    f=b, g=c;
    for(i=0; i<n; ++i) f[i]=a[i];
    gettimeofday(&x, 0);
    for(k=n; k>1; k>>=1, swap(f, g))// split into k polynomials (f)
        for(i=0; i<n; i+=k)// one block of f per iteration
        {
            g1=g+(i>>1), g2=g+(n>>1|i>>1), f1=f+i, f2=f+(i|k>>1), Ome1=Ome[i>>1];
            for(j=0; j<(k>>1); ++j)
                g1[j]=(f1[j]+Ome1*f2[j])%p,
                g2[j]=(f1[j]-Ome1*f2[j])%p;
        }
    gettimeofday(&y, 0);
    // %lld with an explicit cast is portable; %I64d is MSVC-specific.
    printf("%lld\n", (long long)((y.tv_sec-x.tv_sec)*1000000ll+y.tv_usec-x.tv_usec));
    // Intermediate values lie in (-p, p); fold negatives back into [0, p).
    for(i=0; i<n; ++i) a[i]=f[i]<0? f[i]+p: f[i];
    return 0;
}

// Benchmark variant 4 of the length-n transform modulo p: identical to
// variant 3 except that the innermost loop advances the four block pointers
// instead of indexing them with j.
//
// Copies a[0..n) into work buffer b, alternates between b and c, prints the
// elapsed time of the passes in microseconds on its own line, and writes the
// result back to a normalised into [0, p). Uses the power table Ome.
// Always returns 0 (the function previously fell off its end, which is
// undefined behaviour for a non-void function).
inline char dft4(ll *a)
{
    // Static loop variables and the inner loop are kept exactly as measured.
    static ll i, j, k, *f, *g, *g1, *g2, *f1, *f2, Ome1;
    f=b, g=c;
    for(i=0; i<n; ++i) f[i]=a[i];
    gettimeofday(&x, 0);
    for(k=n; k>1; k>>=1, swap(f, g))// split into k polynomials
        for(i=0; i<n; i+=k)// one block of f per iteration
        {
            g1=g+(i>>1), g2=g+(n>>1|i>>1), f1=f+i, f2=f+(i|k>>1), Ome1=Ome[i>>1];
            for(j=0; j<(k>>1); ++j, ++g1, ++g2, ++f1, ++f2)
                *g1=(*f1+Ome1 * *f2)%p,
                *g2=(*f1-Ome1 * *f2)%p;
        }
    gettimeofday(&y, 0);
    // %lld with an explicit cast is portable; %I64d is MSVC-specific.
    printf("%lld\n", (long long)((y.tv_sec-x.tv_sec)*1000000ll+y.tv_usec-x.tv_usec));
    // Intermediate values lie in (-p, p); fold negatives back into [0, p).
    for(i=0; i<n; ++i) a[i]=f[i]<0? f[i]+p: f[i];
    return 0;
}

// Benchmark driver: reads n followed by n integers, then runs the four
// transform variants on the same input; each variant prints its own timing.
// Returns 1 without running anything if the input is malformed or n is not a
// usable transform length.
int main()
{
    // %lld is the standard conversion for long long (%I64d is MSVC-only).
    if(scanf("%lld", &n)!=1) return 1;

    // n must be a power of two that divides p-1 = 119 * 2^23, otherwise
    // 3^((p-1)/n) is not an n-th root of unity; n == 0 would also divide by
    // zero below.
    if(n<1 || n>(1ll<<23) || (n&(n-1))) return 1;

    for(ll i=0; i<n; ++i)
        if(scanf("%lld", a1+i)!=1) return 1;

    // Only the first n elements are ever used, so only those are saved and
    // restored instead of the whole mxn-element array.
    const size_t bytes=(size_t)n*sizeof(ll);
    memcpy(a2, a1, bytes);

    setOme(disPow(3, (p-1)/n));
    dft1(a1), memcpy(a1, a2, bytes);
    dft2(a1), memcpy(a1, a2, bytes);
    dft3(a1), memcpy(a1, a2, bytes);
    dft4(a1), memcpy(a1, a2, bytes);
    return 0;
}

下面是四种写法在 $n=2^{20}$ 时运行的微秒时间:
【51nod 1028】 大数乘法 V2 【FFT/NTT】