POJ 3415 Common Substrings ——后缀数组

时间:2021-07-27 13:59:32

【题目分析】

统计长度不小于 $k$ 的公共子串个数：即满足 $len \ge k$ 且 $A[i..i+len-1] = B[j..j+len-1]$ 的三元组 $(i, j, len)$ 的个数（位置不同的相同子串分别计数）。

朴素的 $O(N^2)$ 做法显然可行：枚举 A 的每个后缀与 B 的每个后缀，累加 $\max(0, \mathrm{LCP} - k + 1)$。

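更好的做法：用一个不在字符集中的分隔符把 A、B 拼接起来建后缀数组，按 sa 顺序扫描 height 数组，并维护一个关于 height 的单调栈来统计贡献——对每个 B 的后缀，累加排在它之前的所有 A 的后缀的贡献；再交换 A、B 的角色扫一遍。两遍扫描都是线性的。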

其实还是挺难写的OTZ。

【代码】

#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>

#include <map>
#include <set>
#include <queue>
#include <string>
#include <iostream>
#include <algorithm>

using namespace std;

// Capacity of every per-position array in this file (the input buffer, the
// suffix-array buffers and the stack). NOTE(review): the concatenated text
// occupies |A| + |B| + 2 slots; confirm this fits the judge's limits.
#define maxn 300005
#define LL long long
// Presumably an "infinity" sentinel (by its name); not referenced by the code shown.
#define inf 0x3f3f3f3f
// Inclusive loops with a fresh LL index: F runs i from j up to k, D runs i
// from j down to k. The arguments are substituted without parentheses.
#define F(i,j,k) for (LL i=j;i<=k;++i)
#define D(i,j,k) for (LL i=j;i>=k;--i)

// Local-debugging helper: unless ONLINE_JUDGE is defined, standard input is
// read from "in.txt" and standard output is written to "wa.txt".
void Finout()
{
#ifndef ONLINE_JUDGE
    const char *inputFile="in.txt";
    const char *outputFile="wa.txt";
    freopen(inputFile,"r",stdin);
    freopen(outputFile,"w",stdout);
#endif
}

// Reads the next integer from stdin, skipping every non-digit character in
// front of it; a '-' among the skipped characters makes the result negative.
// Returns 0 when end of input is reached before any digit.
// ch must be an int, not a char: a char cannot represent EOF distinctly, and
// with it the skip loop never terminated at end of input.
long long Getint()
{
    long long x=0,f=1;
    int ch=getchar();
    while (ch!=EOF&&(ch<'0'||ch>'9'))
    {
        if (ch=='-') f=-1;
        ch=getchar();
    }
    while (ch>='0'&&ch<='9')
    {
        x=x*10+(ch-'0');
        ch=getchar();
    }
    return x*f;
}

// Buffer that main reads each input string into before copying it to arr.s.
char ss[maxn];
// l1 / l2: lengths of the two input strings; k: the length threshold read for
// the current test case. The global n is not referenced by the functions
// shown (they take n as a parameter).
LL n;
LL l1;
LL l2;
LL k;

struct SuffixArray{
	LL s[maxn];
	LL rk[maxn],h[maxn],cnt[maxn],tmp[maxn],sa[maxn];
	void init()
	{
		memset(s,0,sizeof s);
//		memset(rk,0,sizeof rk);
//		memset(h,0,sizeof h);
//		memset(cnt,0,sizeof cnt);
//		memset(tmp,0,sizeof tmp);
//		memset(sa,0,sizeof sa);
	}
	void build(LL n,LL m)
	{
		LL i,j,k; n++;
		F(i,0,2*n+5) rk[i]=h[i]=tmp[i]=sa[i]=0;
		F(i,0,m-1) cnt[i]=0;
		F(i,0,n-1) cnt[rk[i]=s[i]]++;
		F(i,1,m-1) cnt[i]+=cnt[i-1];
		F(i,0,n-1) sa[--cnt[rk[i]]]=i;
		for (k=1;k<=n;k<<=1)
		{
			F(i,0,n-1)
			{
				j=sa[i]-k;
				if (j<0) j+=n;
				tmp[cnt[rk[j]]++]=j;
			}
			sa[tmp[cnt[0]=0]]=j=0;
			F(i,1,n-1)
			{
				if (rk[tmp[i]]!=rk[tmp[i-1]]||rk[tmp[i]+k]!=rk[tmp[i-1]+k]) cnt[++j]=i;
				sa[tmp[i]]=j;
			}
			memcpy(rk,sa,n*sizeof(LL));
			memcpy(sa,tmp,n*sizeof(LL));
			if (j>=n-1) break;
		}
		for (j=rk[h[i=k=0]=0];i<n-1;++i,++k)
			while (~k&&s[i]!=s[sa[j-1]+k]) h[j]=k--,j=rk[sa[j]+1];
		//Debug
		/*
		F(i,0,n-1) cout<<s[i]<<" "; cout<<endl;
		F(i,0,n-1) cout<<sa[i]<<" ";cout<<endl;
		F(i,0,n-1) cout<<h[i]<<" ";cout<<endl;
		*/
		//Debug over
	}
	LL sta[maxn][2],top;
	LL tot,sum;
	void solve(LL n,LL k)
	{
//		n++;
		top=0;sum=0;tot=0;
		F(i,1,n)
		{
			if (h[i]<k) top=tot=0;
			else
			{
				LL cnt=0;
				if (sa[i-1]<l1) cnt++,tot+=h[i]-k+1;
				while (top>0&&h[i]<=sta[top-1][0])
				{
					top--;
					tot-=sta[top][1]*(sta[top][0]-h[i]);
					cnt+=sta[top][1];
				}
				sta[top][0]=h[i]; sta[top++][1]=cnt;
				if (sa[i]>l1) sum+=tot;
			}
		}
		top=tot=0;
		F(i,1,n)
		{
			if (h[i]<k) top=tot=0;
			else
			{
				LL cnt=0;
				if (sa[i-1]>l1) cnt++,tot+=h[i]-k+1;
				while (top>0&&h[i]<=sta[top-1][0])
				{
					top--;
					tot-=sta[top][1]*(sta[top][0]-h[i]);
					cnt+=sta[top][1];
				}
				sta[top][0]=h[i]; sta[top++][1]=cnt;
				if (sa[i]<l1) sum+=tot;
			}
		}
		cout<<sum<<endl;
	}
}arr;

int main()
{
    Finout();
    while (scanf("%lld",&k)!=EOF&&k)
    {
    	arr.init();
    	memset(ss,0,sizeof ss);
    	scanf("%s",ss);l1=strlen(ss);//cout<<l1<<endl;
    	F(i,0,l1-1) arr.s[i]=ss[i];
    	arr.s[l1]=248;
    	memset(ss,0,sizeof ss);
    	scanf("%s",ss);l2=strlen(ss);//cout<<l2<<endl;
    	F(i,0,l2-1) arr.s[l1+i+1]=ss[i];
    	arr.build(l1+l2+1,250);
    	arr.solve(l1+l2+1,k);
	}
}