首先有一个想法,翻转串后直接卷积看有没有0匹配上1。但这是必要而不充分的因为在原串和翻转串中?不能同时取两个值。
先有一些结论:
如果s中长度为len的前缀是border,那么其存在|s|-len的循环节(最后一段不一定完整)。
如果已知len不是s的循环节,那么显然len的因子也不是s的循环节。
如果位置差为len的两个位置无法匹配,那么len不是s的循环节。
于是可得:如果位置差为len的两个位置无法匹配,那么长度为|s|-(len的因子)的前缀不是border。
可以发现其实问号出现冲突的原因就在于此,即某两个01的位置差为len,而问号在两者之间且与其中一个的位置差是len的因子。
那么这是充分的,即只要不会被筛掉则一定是border。
那么我们卷完后只要枚举长度去看他的倍数有没有被筛掉就可以了。由调和级数,复杂度O(nlogn)。
LOJ过了,BZOJ上不出意外地T掉了。
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
int read()
{
int x=,f=;char c=getchar();
while (c<''||c>'') {if (c=='-') f=-;c=getchar();}
while (c>=''&&c<='') x=(x<<)+(x<<)+(c^),c=getchar();
return x*f;
}
#define N 1050000
#define P 998244353
#define inv3 332748118
int n,t,s[N],a[N],b[N],f[N],r[N];
long long ans;
int ksm(int a,int k)
{
if (k==) return ;
int tmp=ksm(a,k>>);
if (k&) return 1ll*tmp*tmp%P*a%P;
else return 1ll*tmp*tmp%P;
}
void DFT(int n,int *a,int p)
{
for (int i=;i<n;i++) if (i<r[i]) swap(a[i],a[r[i]]);
for (int i=;i<=n;i<<=)
{
int wn=ksm(p,(P-)/i);
for (int j=;j<n;j+=i)
{
int w=;
for (int k=j;k<j+(i>>);k++,w=1ll*w*wn%P)
{
int x=a[k],y=1ll*w*a[k+(i>>)]%P;
a[k]=(x+y)%P,a[k+(i>>)]=(x-y+P)%P;
}
}
}
}
void mul(int n)
{
DFT(n,a,),DFT(n,b,);
for (int i=;i<n;i++) a[i]=1ll*a[i]*b[i]%P;
DFT(n,a,inv3);
int inv=ksm(n,P-);
for (int i=;i<n;i++) a[i]=1ll*a[i]*inv%P;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("bzoj5372.in","r",stdin);
freopen("bzoj5372.out","w",stdout);
const char LL[]="%I64d";
#else
const char LL[]="%lld";
#endif
char c=getchar();
while (c==''||c==''||c=='?')
s[n++]=(c=='?'?-:(c^)),c=getchar();
t=;while (t<=(n<<)) t<<=;
for (int i=;i<t;i++) r[i]=(r[i>>]>>)|(i&)*(t>>);
memset(a,,sizeof(a));memset(b,,sizeof(b));
for (int i=;i<n;i++) a[i]=(s[i]==);
for (int i=;i<n;i++) b[i]=(s[n-i-]==);
mul(t);
for (int i=;i<n;i++) f[n-i-]+=(a[i]>);
memset(a,,sizeof(a));memset(b,,sizeof(b));
for (int i=;i<n;i++) a[i]=(s[i]==);
for (int i=;i<n;i++) b[i]=(s[n-i-]==);
mul(t);
for (int i=;i<n;i++) f[n-i-]+=(a[i]>);
for (int i=;i<n;i++)
{
ans^=1ll*(n-i)*(n-i);
for (int j=i;j<=n;j+=i)
if (f[j]) {ans^=1ll*(n-i)*(n-i);break;}
}
ans^=1ll*n*n;
cout<<ans;
return ;
}