题目描述
传送门
题意:给定两个字符串 A 和 B ,求长度不小于 k 的公共子串的个数(可以相同)。
题解
首先把一个串接在另一个串的后面,中间放一个没出现过的字符。
由于每一个子串都是某一个后缀的前缀,求出sa和height了之后,我们可以将height分组,组内都是height>=k的后缀。可以知道长度不小于k的公共子串是两个后缀的前缀,并且它们一定在同一组内。
那么对于每一组,从前往后扫,假设遇到了A的后缀,那么统计一下它前面与B的后缀能组成多少>=k的公共子串,然后再把A和B反一下,这样扫两遍,就求出了答案。
关键是怎么统计。暴力统计是
代码
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
#define LL long long
#define N 200005
int n,m,k,ls,la;
char a[N],s[N];
int *x,*y,X[N],Y[N],c[N],sa[N],height[N],rank[N];
int stack[N],cnt[N],top;
LL sum,ans;
void clear()
{
n=m=ls=la=top=0;sum=ans=0LL;
memset(X,0,sizeof(X));memset(Y,0,sizeof(Y));memset(c,0,sizeof(c));
memset(sa,0,sizeof(sa));memset(height,0,sizeof(height));memset(rank,0,sizeof(rank));
}
void build_sa()
{
m=200;
x=X,y=Y;
for (int i=0;i<m;++i) c[i]=0;
for (int i=0;i<n;++i) ++c[x[i]=s[i]];
for (int i=1;i<m;++i) c[i]+=c[i-1];
for (int i=n-1;i>=0;--i) sa[--c[x[i]]]=i;
for (int k=1;k<=n;k<<=1)
{
int p=0;
for (int i=n-k;i<n;++i) y[p++]=i;
for (int i=0;i<n;++i) if (sa[i]>=k) y[p++]=sa[i]-k;
for (int i=0;i<m;++i) c[i]=0;
for (int i=0;i<n;++i) ++c[x[y[i]]];
for (int i=1;i<m;++i) c[i]+=c[i-1];
for (int i=n-1;i>=0;--i) sa[--c[x[y[i]]]]=y[i];
swap(x,y);
p=1,x[sa[0]]=0;
for (int i=1;i<n;++i)
x[sa[i]]=y[sa[i-1]]==y[sa[i]]&&((sa[i-1]+k<n?y[sa[i-1]+k]:-1)==(sa[i]+k<n?y[sa[i]+k]:-1))?p-1:p++;
if (p>n) break;
m=p;
}
}
void build_height()
{
for (int i=0;i<n;++i) rank[sa[i]]=i;
int k=0;height[0]=0;
for (int i=0;i<n;++i)
{
if (!rank[i]) continue;
if (k) --k;
int j=sa[rank[i]-1];
while (i+k<n&&j+k<n&&s[i+k]==s[j+k]) ++k;
height[rank[i]]=k;
}
}
int main()
{
while (~scanf("%d\n",&k))
{
if (!k) break;
clear();
gets(s);gets(a);ls=strlen(s);la=strlen(a);
s[ls]='$';
for (int i=0;i<la;++i) s[ls+i+1]=a[i];
n=ls+la+1;
build_sa();
build_height();
for (int i=0;i<n;++i)
{
if (height[i]<k)
{
top=sum=0;
continue;
}
int num=0;
while (top&&stack[top]>height[i])
{
sum-=(LL)((stack[top]-k+1)*cnt[top]);
sum+=(LL)((height[i]-k+1)*cnt[top]);
num+=cnt[top];--top;
}
stack[++top]=height[i];
if (sa[i-1]>ls)
{
sum+=(LL)(height[i]-k+1);
cnt[top]=num+1;
}
else cnt[top]=num;
if (sa[i]<ls)
ans+=sum;
}
for (int i=0;i<n;++i)
{
if (height[i]<k)
{
top=sum=0;
continue;
}
int num=0;
while (top&&stack[top]>height[i])
{
sum-=(LL)((stack[top]-k+1)*cnt[top]);
sum+=(LL)((height[i]-k+1)*cnt[top]);
num+=cnt[top];--top;
}
stack[++top]=height[i];
if (sa[i-1]<ls)
{
sum+=(LL)(height[i]-k+1);
cnt[top]=num+1;
}
else cnt[top]=num;
if (sa[i]>ls)
ans+=sum;
}
printf("%lld\n",ans);
}
}