洛谷P4199 万径人踪灭(manacher+FFT)

时间:2022-05-11 12:08:14

传送门

题目所求为所有的不连续回文子序列个数,可以转化为回文子序列数-回文子串数

回文子串manacher跑一跑就行了,考虑怎么求回文子序列数

我们考虑,如果$S_i$是回文子序列的对称中心,那么只要$S_{i-j}$和$S_{i+j}$相等,我们就多了一种选择

设共有$x$组相等的,那么以$S_i$为对称中心的回文子序列个数就是$2^{x+1}-1$,表示这$x$组包括对称中心选或不选,除去全都不选的都能算入答案

然而对称中心不一定在字符上可能在两个字符中间,那么这种时候回文子序列数就是$2^x-1$(因为没有中间的字符所以无所谓选不选)

然后考虑如何计算每一个位置上的$x$

我们考虑构造多项式$A$,原串上为$a$的位置设为$1$,$b$的位置设为$0$,如果$s[i]==s[j]$,那么他们的贡献会加到$(i+j)/2$上

然后发现这玩意儿和卷积很像,于是我们把除以二去掉,那么每一对$s[i]==s[j]$都会把贡献加到$i+j$上,所以只要把$A$自乘一下就可以了

然后构造$B$,原串上为$b$的位置设为$0$,也自乘一下就行了

然后把$A$和$B$对应系数加起来,再减去((i&1)^1)(表示这一位对称中心是否在字符上),再求一下$2$的多少次幂减一,最后减掉回文串个数就行了

 //minamoto
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define ll long long
#define add(x,y) (x+y>=P?x+y-P:x+y)
using namespace std;
const int N=5e5+,P=1e9+;const double Pi=acos(-1.0);
struct complex{
double x,y;
complex(double xx=,double yy=){x=xx,y=yy;}
inline complex operator +(complex b){return complex(x+b.x,y+b.y);}
inline complex operator -(complex b){return complex(x-b.x,y-b.y);}
inline complex operator *(complex b){return complex(x*b.x-y*b.y,x*b.y+y*b.x);}
}A[N],B[N];
int n,m,l,r[N],limit=;
void FFT(complex *A,int type){
for(int i=;i<limit;++i)
if(i<r[i]) swap(A[i],A[r[i]]);
for(int mid=;mid<limit;mid<<=){
complex Wn(cos(Pi/mid),type*sin(Pi/mid));
for(int R=mid<<,j=;j<limit;j+=R){
complex w(,);
for(int k=;k<mid;++k,w=w*Wn){
complex x=A[j+k],y=w*A[j+k+mid];
A[j+k]=x+y,A[j+k+mid]=x-y;
}
}
}
}
ll s[N],ans[N],x[N],p[N],sum;char ss[N];
void manacher(){
int mx=,id;
for(int i=,l=(n<<)+;i<=l;++i){
p[i]=i<mx?min(p[*id-i],(ll)mx-i):1ll;
while(x[i-p[i]]==x[i+p[i]]) ++p[i];
if(mx<i+p[i]) mx=i+p[i],id=i;
}
}
ll ksm(ll a,ll b){
ll res=;
while(b){
if(b&) res=res*a%P;
a=a*a%P,b>>=;
}
return res;
}
int main(){
// freopen("testdata.in","r",stdin);
scanf("%s",ss+);n=strlen(ss+);
for(int i=;i<=n;++i) s[i]=ss[i]=='a';
for(int i=;i<=n;++i)
x[(i<<)-]=,x[i<<]=s[i];
x[]=-,x[(n+)<<]=-,x[(n<<)+]=;
while(limit<=(n<<)+) limit<<=,++l;
for(int i=;i<limit;++i)
r[i]=(r[i>>]>>)|((i&)<<(l-));
for(int i=;i<=n;++i) A[i].x=B[i].x=s[i];
FFT(A,),FFT(B,);
for(int i=;i<limit;++i) A[i]=A[i]*B[i];
FFT(A,-);
for(int i=,l=(n<<)+;i<=l;++i) ans[i]+=((ll)(A[i].x/limit+0.5)-((i&)^));
memset(A,,sizeof(A)),memset(B,,sizeof(B));
for(int i=;i<=n;++i) A[i].x=B[i].x=(s[i]^);
FFT(A,),FFT(B,);
for(int i=;i<limit;++i) A[i]=A[i]*B[i];
FFT(A,-);
for(int i=,l=(n<<)+;i<=l;++i) ans[i]+=((ll)(A[i].x/limit+0.5)-((i&)^));
for(int i=,l=(n<<)+;i<=l;++i) ans[i]=((ans[i]+((i&)^))>>)+((i&)^);
for(int i=,l=(n<<)+;i<=l;++i) ans[i]=ksm(,ans[i])-;
manacher();
for(int i=,l=(n<<)+;i<=l;++i) ans[i]-=(p[i]>>);
for(int i=,l=(n<<)+;i<=l;++i) sum=add(sum,ans[i]);
printf("%lld\n",sum);
return ;
}