[POJ3415]Common Substrings(后缀数组+单调栈)

时间:2022-09-13 16:06:14

题目描述

传送门
题意:给定两个字符串 A 和 B ,求长度不小于 k 的公共子串的个数(可以相同)。

题解

首先把一个串接在另一个串的后面,中间放一个没出现过的字符。
由于每一个子串都是某一个后缀的前缀,求出sa和height了之后,我们可以将height分组,组内都是height>=k的后缀。可以知道长度不小于k的公共子串是两个后缀的前缀,并且它们一定在同一组内
那么对于每一组,从前往后扫,假设遇到了A的后缀,那么统计一下它前面与B的后缀能组成多少>=k的公共子串,然后再把A和B反一下,这样扫两遍,就求出了答案。
关键是怎么统计。暴力统计是 O(n2) 的肯定不行。我们知道两个后缀的最长公共前缀是它们的区间最小值,所以可以维护一个自底向上单调递增的栈。同时需要维护的是,栈中的总和已经栈中每一个元素的出现次数。需要注意的是,只有需要统计的一种后缀的height是有价值的,不需要统计的一种后缀的height需要入栈,但是没有价值,这里的价值也就是上文说道的“出现次数”。每一次弹栈,相当于是用较小的height替换了较大的height,但是总个数不能改变。

代码

#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
#define LL long long
#define N 200005

int n,m,k,ls,la;
char a[N],s[N];
int *x,*y,X[N],Y[N],c[N],sa[N],height[N],rank[N];
int stack[N],cnt[N],top;
LL sum,ans;

void clear()
{
    n=m=ls=la=top=0;sum=ans=0LL;
    memset(X,0,sizeof(X));memset(Y,0,sizeof(Y));memset(c,0,sizeof(c));
    memset(sa,0,sizeof(sa));memset(height,0,sizeof(height));memset(rank,0,sizeof(rank));
}
void build_sa()
{
    m=200;
    x=X,y=Y;
    for (int i=0;i<m;++i) c[i]=0;
    for (int i=0;i<n;++i) ++c[x[i]=s[i]];
    for (int i=1;i<m;++i) c[i]+=c[i-1];
    for (int i=n-1;i>=0;--i) sa[--c[x[i]]]=i;

    for (int k=1;k<=n;k<<=1)
    {
        int p=0;
        for (int i=n-k;i<n;++i) y[p++]=i;
        for (int i=0;i<n;++i) if (sa[i]>=k) y[p++]=sa[i]-k;

        for (int i=0;i<m;++i) c[i]=0;
        for (int i=0;i<n;++i) ++c[x[y[i]]];
        for (int i=1;i<m;++i) c[i]+=c[i-1];
        for (int i=n-1;i>=0;--i) sa[--c[x[y[i]]]]=y[i];

        swap(x,y);
        p=1,x[sa[0]]=0;
        for (int i=1;i<n;++i)
            x[sa[i]]=y[sa[i-1]]==y[sa[i]]&&((sa[i-1]+k<n?y[sa[i-1]+k]:-1)==(sa[i]+k<n?y[sa[i]+k]:-1))?p-1:p++;
        if (p>n) break;
        m=p;
    }
}
void build_height()
{
    for (int i=0;i<n;++i) rank[sa[i]]=i;
    int k=0;height[0]=0;
    for (int i=0;i<n;++i)
    {
        if (!rank[i]) continue;
        if (k) --k;
        int j=sa[rank[i]-1];
        while (i+k<n&&j+k<n&&s[i+k]==s[j+k]) ++k;
        height[rank[i]]=k;
    }
}
int main()
{
    while (~scanf("%d\n",&k))
    {
        if (!k) break;
        clear();
        gets(s);gets(a);ls=strlen(s);la=strlen(a);
        s[ls]='$';
        for (int i=0;i<la;++i) s[ls+i+1]=a[i];
        n=ls+la+1;
        build_sa();
        build_height();

        for (int i=0;i<n;++i)
        {
            if (height[i]<k)
            {
                top=sum=0;
                continue;
            }
            int num=0;
            while (top&&stack[top]>height[i])
            {
                sum-=(LL)((stack[top]-k+1)*cnt[top]);
                sum+=(LL)((height[i]-k+1)*cnt[top]);
                num+=cnt[top];--top;
            }
            stack[++top]=height[i];
            if (sa[i-1]>ls)
            {
                sum+=(LL)(height[i]-k+1);
                cnt[top]=num+1;
            }
            else cnt[top]=num;
            if (sa[i]<ls)
                ans+=sum;
        }
        for (int i=0;i<n;++i)
        {
            if (height[i]<k)
            {
                top=sum=0;
                continue;
            }
            int num=0;
            while (top&&stack[top]>height[i])
            {
                sum-=(LL)((stack[top]-k+1)*cnt[top]);
                sum+=(LL)((height[i]-k+1)*cnt[top]);
                num+=cnt[top];--top;
            }
            stack[++top]=height[i];
            if (sa[i-1]<ls)
            {
                sum+=(LL)(height[i]-k+1);
                cnt[top]=num+1;
            }
            else cnt[top]=num;
            if (sa[i]>ls)
                ans+=sum;
        }
        printf("%lld\n",ans);
    }
}