HDU2222 Keywords Search(AC自动机模板)

时间:2022-01-14 09:34:30

AC自动机是一种多模式匹配的算法。大概过程如下:

  • 首先所有模式串构造一棵Trie树,Trie树上的每个非根结点都代表一个从根出发到该点路径的字符串。
  • 然后每个结点都计算出其fail指针的值,这个fail指针就指向这个结点所表示字符串的最长存在的后缀所对应的结点,如果不存在就指向根:计算每个结点的fail用BFS,比如当前结点u出队要拓展并计算其孩子结点的fail,v是其第k个孩子,fail[v]的值就是某个fail[fail[fail...[u]]]存在第k孩子结点其第k个孩子结点,如果不存在fail[v]就等于root。
  • 最后主串就往Trie树上跑,在某个Trie树结点失配了就跳转到这个结点fail指针所指的结点继续跑——不过如果匹配了某个模式串这时可能某个模式串的后缀串被忽略了,所以需要用到temp指针,去检查是否有遗漏后缀没匹配。

而这题大概就是给几个模式串,一个主串,问有几个模式串被主串匹配。

AC自动机的模板题。有个可以优化的地方就是某个模式串被匹配了,下一次经过这儿就可以跳过了temp指针的过程了。

代码参考自kuangbin巨的博客,太简洁了(300+ms):

 #include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
int tn,ch[][],cnt[],fail[];
void insert(char *s){
int x=;
for(int i=; s[i]; ++i){
int y=s[i]-'a';
if(ch[x][y]==) ch[x][y]=++tn;
x=ch[x][y];
}
++cnt[x];
}
void init(){
memset(fail,,sizeof(fail));
queue<int> que;
for(int i=; i<; ++i){
if(ch[][i]) que.push(ch[][i]);
}
while(!que.empty()){
int x=que.front(); que.pop();
for(int i=;i<;++i){
if(ch[x][i]) que.push(ch[x][i]),fail[ch[x][i]]=ch[fail[x]][i];
else ch[x][i]=ch[fail[x]][i];
}
}
}
int query(char *s){
int x=,res=;
for(int i=; s[i]; ++i){
int tmp=x=ch[x][s[i]-'a'];
while(tmp){
if(cnt[tmp]>=){
res+=cnt[tmp];
cnt[tmp]=-;
}else break;
tmp=fail[tmp];
}
}
return res;
}
char S[],T[];
int main(){
int t,n;
scanf("%d",&t);
while(t--){
tn=;
memset(ch,,sizeof(ch));
memset(cnt,,sizeof(cnt));
scanf("%d",&n);
while(n--){
scanf("%s",T);
insert(T);
}
init();
scanf("%s",S);
printf("%d\n",query(S));
}
return ;
}

另外之前学的指针版本的,指针版本跑得更快(200+ms):

 #include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
typedef struct Node *pNode;
struct Node{
int cnt;
pNode fail,nxt[];
Node(){
cnt=; fail=NULL;
for(int i=;i<;++i) nxt[i]=NULL;
}
};
pNode root;
char S[];
void insert(char *s){
pNode p=root;
for(int i=;s[i];++i){
int index=s[i]-'a';
if(p->nxt[index]==NULL){
p->nxt[index]=new Node;
}
p=p->nxt[index];
}
++p->cnt;
}
void init(){
queue<pNode> que;
que.push(root);
while(que.size()){
pNode y=que.front(); que.pop();
for(int i=;i<;++i){
if(y->nxt[i]==NULL) continue;
if(y==root){
y->nxt[i]->fail=root;
que.push(y->nxt[i]);
continue;
}
pNode x=y->fail;
while(x&&x->nxt[i]==NULL) x=x->fail;
if(x==NULL) y->nxt[i]->fail=root;
else y->nxt[i]->fail=x->nxt[i];
que.push(y->nxt[i]);
}
}
}
int query(){
int res=;
pNode x=root;
for(int i=;S[i];++i){
int index=S[i]-'a';
while(x->nxt[index]==NULL&&x!=root) x=x->fail;
x=x->nxt[index];
if(x==NULL) x=root;
pNode y=x;
while(y!=root){
if(y->cnt>=){
res+=y->cnt;
y->cnt=-;
}else break;
y=y->fail;
}
}
return res;
}
int main(){
int t,n;
char s[];
scanf("%d",&t);
while(t--){
root=new Node;
scanf("%d",&n);
for(int i=;i<n;++i){
scanf("%s",s);
insert(s);
}
scanf("%s",S);
init();
printf("%d\n",query());
}
return ;
}