hdu 5730 Shell Necklace —— 分治FFT

时间:2021-06-29 21:47:41

题目:http://acm.hdu.edu.cn/showproblem.php?pid=5730

DP式:\( f[i] = \sum\limits_{j=1}^{i} f[i-j] \cdot s[j] \),其中 \( f[0] = 1 \),\( s[j] \) 为读入的第 \( j \) 个数。

因为代码中没有把 \( f[0] \) 设为 1,\( j = i \) 的那一项 \( f[0] \cdot s[i] \) 不会被卷积计算到,所以在递归底层令 \( f[l] \mathrel{+}= s[l] \) 把这一项补上。

注意有多组数据(以 0 结束),每组开始前都要清空数组。

读入 \( s[i] \) 时要先对 313 取模!!否则卷积中的数值过大,double 精度的 FFT 可能算错。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
// ---- shared types, constants and global state -------------------------------
typedef long long ll;          // wide integer for rounded convolution values
typedef double db;             // floating type used throughout the FFT
int const xn=(1<<18);          // capacity of every global array (must cover the FFT length)
int const mod=313;             // all DP values are kept modulo 313
db const Pi=acos(-1.0);
int n;                         // length requested by the current test case
int rev[xn];                   // bit-reversal permutation for the current FFT length
int f[xn];                     // DP values f[i] (mod 313); f[0] is never stored
int s[xn];                     // coefficients s[j] read from input, reduced mod 313

// Minimal complex number: x is the real part, y the imaginary part.
struct cpl{
  db x,y;
  cpl(db re=0,db im=0):x(re),y(im) {}
};
cpl a[xn],b[xn];               // FFT work buffers

cpl operator + (cpl p,cpl q){return cpl(p.x+q.x,p.y+q.y);}
cpl operator - (cpl p,cpl q){return cpl(p.x-q.x,p.y-q.y);}
// (p.x + i*p.y) * (q.x + i*q.y) = (p.x*q.x - p.y*q.y) + i*(p.x*q.y + p.y*q.x)
cpl operator * (cpl p,cpl q)
{
  db re=p.x*q.x-p.y*q.y;
  db im=p.x*q.y+p.y*q.x;
  return cpl(re,im);
}
// Reads the next integer from stdin, skipping any non-digit characters before it.
// A '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit; main() treats 0 as
// "no more test cases", so a missing terminating 0 no longer hangs the program.
int rd()
{
  int ret=0;
  bool neg=false;      // renamed from `f` so it no longer shadows the global DP array
  int ch=getchar();    // int, not char: EOF must stay distinguishable from real bytes
  while(ch<'0'||ch>'9')
    {
      if(ch==EOF)return 0;
      if(ch=='-')neg=true;
      ch=getchar();
    }
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return neg?-ret:ret;
}
// In-place iterative radix-2 FFT of a[0..lim-1].
// tp = 1  : forward transform.
// tp = -1 : inverse transform; afterwards only the real parts are divided by lim
//           (callers read only .x).
// Requires lim to be a power of two and rev[0..lim-1] to already hold the
// bit-reversal permutation for lim.
void fft(cpl *a,int tp,int lim)
{
  // Reorder into bit-reversed index order so the butterflies can run in place.
  for(int i=0;i<lim;i++)
    {
      int j=rev[i];
      if(i>=j)continue;
      cpl t=a[i]; a[i]=a[j]; a[j]=t;
    }
  // half = length of the sub-transforms merged in this pass; doubles each pass.
  for(int half=1;half<lim;half<<=1)
    {
      // Principal root of unity for blocks of length 2*half.
      const cpl step(cos(Pi/half),tp*sin(Pi/half));
      for(int base=0;base<lim;base+=half<<1)
        {
          cpl *lo=a+base,*hi=a+base+half;
          cpl w(1,0);
          for(int k=0;k<half;k++)
            {
              cpl u=lo[k],v=w*hi[k];
              lo[k]=u+v;
              hi[k]=u-v;
              w=w*step;
            }
        }
    }
  if(tp!=1)
    for(int i=0;i<lim;i++)a[i].x/=lim;
}
// Maps any int x to its canonical residue in [0, mod).
int upt(int x){x%=mod; if(x<0)x+=mod; return x;}
void work(int l,int r)
{
  if(l==r){f[l]=upt(f[l]+s[l]); return;}//f[0]=0...
  int len=r-l+1,mid=((l+r)>>1);
  work(l,mid);
  int lim=1,L=0;
  while(lim<len)lim<<=1,L++;
  for(int i=0;i<lim;i++)rev[i]=((rev[i>>1]>>1)|((i&1)<<(L-1)));
  for(int i=l;i<=mid;i++)a[i-l].x=f[i],a[i-l].y=0;//y
  for(int i=mid-l+1;i<lim;i++)a[i].x=0,a[i].y=0;
  for(int i=0;i<lim;i++)b[i].x=s[i],b[i].y=0;
  fft(a,1,lim); fft(b,1,lim);
  for(int i=0;i<lim;i++)a[i]=a[i]*b[i];
  fft(a,-1,lim);
  for(int i=mid+1;i<=r;i++)f[i]+=(ll)(a[i-l].x+0.5)%mod;
  work(mid+1,r);
}
// Processes test cases until a length of 0 is read; prints f[n] (mod 313) for each.
int main()
{
  for(;;)
    {
      n=rd();
      if(n==0)break;
      // Reset both arrays: the leaves of work() accumulate into f[], and work()
      // reads s[] up to the FFT length (past n), so no stale data may survive.
      memset(f,0,sizeof f);
      memset(s,0,sizeof s);
      for(int i=1;i<=n;i++)
        s[i]=rd()%mod;  // reduce right away to keep FFT magnitudes small
      work(1,n);
      printf("%d\n",f[n]);
    }
  return 0;
}