Codeforces 954I Yet Another String Matching Problem FFT

时间:2022-12-22 10:20:45

题意

有两个字符串,每次可以进行如下操作:选择两个字符 c 1 , c 2 ,然后把两个字符串中所有的 c 1 替换成 c 2 。定义两个长度相等的字符串的距离为,最少需要多少次操作使得两个字符串相等。现在给出长度为n的字符串S和长度为m的字符串T。问对于S中所有长度为m的子串,和T的距离是多少。
n , m 125000 ,字符集大小为 6

分析

先考虑如何求两个字符串的距离。
对于一个位置相等的字符对 ( c 1 , c 2 ) ,我们在 c 1 c 2 之间连一条边,那么答案就是 6 。因为连通块之间肯定互不影响,且一个大小为 n 的连通块必然可以通过 n 1 次操作来满足条件。
接下来我们就枚举两种字符,看它们在哪些位置有连边,然后FFT一下就做完了。

代码

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>

const int N=125005;
const double pi=acos(-1);

int n,m,L,rev[N*4],f[N][6];
char s[N],t[N];
struct com
{
    double x,y;

    com operator + (const com &d) const {return (com){x+d.x,y+d.y};}
    com operator - (const com &d) const {return (com){x-d.x,y-d.y};}
    com operator * (const com &d) const {return (com){x*d.x-y*d.y,x*d.y+y*d.x};}
    com operator / (const double &d) const {return (com){x/d,y/d};}
}a[N*4],b[N*4];

int find(int x,int y)
{
    if (f[x][y]==y) return y;
    else return f[x][y]=find(x,f[x][y]);
}

void FFT(com *a,int f)
{
    for (int i=0;i<L;i++) if (i<rev[i]) std::swap(a[i],a[rev[i]]);
    for (int i=1;i<L;i<<=1)
    {
        com wn=(com){cos(pi/i),f*sin(pi/i)};
        for (int j=0;j<L;j+=(i<<1))
        {
            com w=(com){1,0};
            for (int k=0;k<i;k++)
            {
                com u=a[j+k],v=a[j+k+i]*w;
                a[j+k]=u+v;a[j+k+i]=u-v;
                w=w*wn;
            }
        }
    }
    if (f==-1) for (int i=0;i<L;i++) a[i]=a[i]/L;
}

int main()
{
    scanf("%s%s",s+1,t+1);
    n=strlen(s+1);
    m=strlen(t+1);
    int lg=0;
    for (L=1;L<=n*2;L<<=1,lg++);
    for (int i=0;i<L;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
    for (int i=1;i<=n-m+1;i++)
        for (int j=0;j<=5;j++)
            f[i][j]=j;
    for (int i=0;i<=5;i++)
        for (int j=0;j<=5;j++)
        {
            if (i==j) continue;
            for (int k=0;k<L;k++) a[k]=b[k]=(com){0,0};
            for (int k=1;k<=n;k++) if (s[k]-'a'==i) a[k]=(com){1,0};
            for (int k=1;k<=m;k++) if (t[k]-'a'==j) b[m-k+1]=(com){1,0};
            FFT(a,1);FFT(b,1);
            for (int k=0;k<L;k++) a[k]=a[k]*b[k];
            FFT(a,-1);
            for (int k=1;k<=n-m+1;k++)
                if ((int)(a[k+m].x+0.5)>0)
                {
                    int x=find(k,i),y=find(k,j);
                    if (x!=y) f[k][x]=y;
                }
        }
    for (int i=1;i<=n-m+1;i++)
    {
        int ans=0;
        for (int j=0;j<=5;j++) ans+=(f[i][j]!=j);
        printf("%d ",ans);
    }
    return 0;
}