P3181 [HAOI2016]找相同字符

时间:2022-05-02 22:18:40

思路

广义SAM

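把两个字符串插入同一个广义SAM。对每个节点,分别统计它在第一个串和第二个串中的endpos大小 cnt0、cnt1(沿后缀链接树自底向上累加)。该节点代表 maxlen-minlen+1 个不同子串,每个子串在两串中的出现次数都是 cnt0、cnt1,所以答案为 Σ (maxlen-minlen+1)·cnt0·cnt1。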

答案可能远超int范围,记得开long long

代码

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>
using namespace std;
// Node capacity. Presumably sized for 1 + 2*(n1+n2) generalized-SAM nodes with
// n1, n2 <= 200000 (nodes are 1-based; node 0 is the sentinel the root links to).
// NOTE(review): confirm against the problem's length bounds.
constexpr int MAXN = 800100;

int endpos[MAXN][2];   // endpos[x][k]: end positions of node x inside input string k (set in add_len, summed in get_sz)
int trans[MAXN][26];   // trans[x][c]: target of the transition on letter c, 0 = no transition
int suflink[MAXN];     // suffix link of each node; the root (1) keeps the zero-initialised value 0
int maxlen[MAXN];      // length of the longest string represented by node x
int minlen[MAXN];      // length of the shortest string represented by node x
int Nodecnt;           // index of the most recently allocated node
int in[MAXN];          // number of suffix-link children, consumed by get_sz
int n;                 // length of the string currently being processed

char s[MAXN];          // input buffer, read 1-based at s+1
// Allocate the next SAM node (index Nodecnt+1) and fill in its fields.
//   _maxlen  : length of the longest string the node represents
//   _minlen  : length of the shortest one (callers may pass 0 and fix it later)
//   _trans   : transition row to copy when cloning; when NULL the row is left
//              untouched (all zero, since node slots are never reused and the
//              global array starts zero-initialised)
//   _suflink : suffix-link target
// Returns the index of the new node.
int New_state(int _maxlen,int _minlen,int *_trans,int _suflink){
    const int id=++Nodecnt;
    maxlen[id]=_maxlen;
    minlen[id]=_minlen;
    suflink[id]=_suflink;
    if(_trans!=nullptr)
        memcpy(trans[id],_trans,sizeof(trans[id]));
    return id;
}
// Extend state u of the generalized SAM by the letter index c (a column of
// trans[][26]) and return the state for the new prefix. The caller (main) keeps
// the returned state as its new `last`.
// inq selects which input string (0 or 1) this end position belongs to: the
// returned "real" state gets endpos[.][inq] incremented.
//
// Two cases:
//  * trans[u][c] already exists (the same prefix was inserted earlier, e.g. by
//    the other string). Reuse the target v if maxlen[v]==maxlen[u]+1.
//    Otherwise split v with a clone y and return y. No new node z is created
//    in this case.
//  * Otherwise: the usual SAM extension (new state z, walk suffix links,
//    possibly clone).
// Every node whose suffix link is set or changed here also gets
// minlen = maxlen[suflink]+1.
int add_len(int u,int c,int inq){
if(trans[u][c]){
int v=trans[u][c];
if(maxlen[v]==maxlen[u]+1){
// v represents exactly (prefix + c): just record one more end position.
endpos[v][inq]++;
return v;
}
else{
// v's longest string is longer than prefix+c: split off a clone y whose
// maxlen is maxlen[u]+1 and let y stand for this end position.
// (minlen 0 is a placeholder, fixed below once suflink[y] is final.)
int y=New_state(maxlen[u]+1,0,trans[v],suflink[v]);
endpos[y][inq]++;
suflink[v]=y;
minlen[v]=maxlen[y]+1;
// Redirect the c-transitions that pointed to v along u's suffix chain to y.
while(u&&(trans[u][c]==v)){
trans[u][c]=y;
u=suflink[u];
}
minlen[y]=maxlen[suflink[y]]+1;
return y;
}
}
else{
// Standard extension: z represents the new prefix.
int z=New_state(maxlen[u]+1,0,NULL,0);
endpos[z][inq]++;
// Give z as the c-target to every state on the suffix chain that lacks a
// c-transition. The walk stops at the first state that has one, or after
// leaving the root (1) for its suffix link 0.
while(u&&(trans[u][c]==0)){
trans[u][c]=z;
u=suflink[u];
}
if(!u){
// No state on the chain had a c-transition: z links to the root.
suflink[z]=1;
minlen[z]=1;
return z;
}
int v=trans[u][c];
if(maxlen[v]==maxlen[u]+1){
// v can serve directly as z's suffix link.
suflink[z]=v;
minlen[z]=maxlen[v]+1;
return z;
}
// Otherwise clone v as y; y becomes the suffix link of both v and z.
// The clone receives no endpos increment here; its counts are accumulated
// later from its suffix-link children in get_sz().
int y=New_state(maxlen[u]+1,0,trans[v],suflink[v]);
suflink[v]=suflink[z]=y;
minlen[v]=minlen[z]=maxlen[y]+1;
while(u&&(trans[u][c]==v)){
trans[u][c]=y;
u=suflink[u];
}
minlen[y]=maxlen[suflink[y]]+1;
return z;
}
}
queue<int> q;
// Accumulate per-string end-position counts up the suffix-link tree.
// Afterwards endpos[x][k] is the number of end positions of node x's strings
// inside input string k.
// Nodes are processed children-before-parent: in[x] counts x's suffix-link
// children, and a node is enqueued once all of its children are done.
// Only real nodes 1..Nodecnt take part. Node 0 is just the sentinel that the
// root's suffix link points to, so it is never enqueued or credited. (The
// previous version enqueued node 0, let it add into itself, and drove in[0]
// negative.)
void get_sz(void){
    for(int i=2;i<=Nodecnt;i++)
        in[suflink[i]]++;
    for(int i=1;i<=Nodecnt;i++)
        if(!in[i])
            q.push(i);
    while(!q.empty()){
        const int x=q.front();
        q.pop();
        const int p=suflink[x];
        // Only the root has suflink 0: there is no parent to credit.
        if(p==0)
            continue;
        endpos[p][0]+=endpos[x][0];
        endpos[p][1]+=endpos[x][1];
        if(--in[p]==0)
            q.push(p);
    }
}
long long ans=0;
// Read two lowercase strings and insert both into one generalized SAM
// (string 0, then string 1, each restarting from the root). Then propagate the
// per-string end-position counts and print
//   sum over nodes i>=2 of (maxlen[i]-minlen[i]+1) * cnt0(i) * cnt1(i).
// Each node stands for maxlen-minlen+1 distinct substrings, and each of them
// occurs cnt0 times in string 0 and cnt1 times in string 1.
int main(){
    Nodecnt=1;   // node 1 is the root; its suffix link stays at sentinel 0
    for(int k=0;k<2;k++){
        // Field width 200000: bounds the write into s (previously unbounded %s).
        // With n1, n2 <= 200000 the automaton needs at most 1+2*(n1+n2) nodes,
        // which stays below MAXN.
        if(scanf("%200000s",s+1)!=1)
            return 1;   // missing input: do not reuse a stale buffer
        n=strlen(s+1);
        int last=1;     // every string is inserted starting from the root
        for(int i=1;i<=n;i++){
            // s[i]-'a' indexes trans[][26]; anything else would write out of bounds.
            if(s[i]<'a'||s[i]>'z')
                return 1;
            last=add_len(last,s[i]-'a',k);
        }
    }
    get_sz();
    for(int i=2;i<=Nodecnt;i++)
        // Left-to-right evaluation keeps the whole product in long long.
        ans+=(long long)(maxlen[i]-minlen[i]+1)*endpos[i][0]*endpos[i][1];
    printf("%lld\n",ans);
    return 0;
}