POJ2778 DNA sequence

时间:2021-08-15 06:49:50

题目大意:给出m个疾病基因片段(m<=10),每个片段不超过10个字符。求长度为n的不包含任何一个疾病基因片段的DNA序列共有多少种?(n<=2000000000)

分析:本题需要对m个疾病基因片段构建一个AC自动机,自动机中的每个节点表示一个状态。其中AC自动机中的叶子节点表示的是病毒,所以是非法状态。同时,如果某个节点到根的字符串的后缀是一个病毒,那么该节点也是非法状态。剔除掉所有的非法状态,那么剩下的节点都表示合法状态了。然后用节点的nxt指针表示状态之间转化关系。若nxt[i]==0,则nxt[i]指针指向当前节点fail指针的nxt[i],如果仍然为0,则nxt[i]指向根节点。这样处理以后,每个指针都不会指向0。这样,该自动机可以看做是一个合法状态的转换图,节点表示各种合法状态,边表示添加一个字符将转换为另一个状态。于是我们可以得到一个矩阵。该矩阵实际上表示该状态图的邻接矩阵。对该矩阵自乘n次。最后结果矩阵的第1行各元素之和表示从空状态添加n个字符能够得到的所有合法状态的数量。

矩阵的思想非常巧妙。

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
#define MAXN 102
#define MAXL 12
#define MAXC 4
#define MOD 100000
struct node
{
int fail,nxt[6],flag;
}trie[MAXN];
int head,tail,myq[MAXN],root=1,tot=1;
char word[MAXL];
int degree;
int a[MAXN][MAXN],b[MAXN][MAXN],c[MAXN][MAXN],(*ans)[MAXN];
void multi(int (*a)[MAXN],int (*b)[MAXN],int (*c)[MAXN])
{
for(int i=1;i<=degree;i++)
{
for(int j=1;j<=degree;j++)
c[i][j]=0;
}
for(int i=1;i<=degree;i++)
{
for(int j=1;j<=degree;j++)
{
for(int k=1;k<=degree;k++)
{
c[i][j]+=(long long)a[i][k]*b[k][j]%MOD;
c[i][j]%=MOD;
}
}
}
}
void power(int (*t1)[MAXN],int h)
{
for(int i=1;i<=degree;i++)
for(int j=1;j<=degree;j++)
b[i][j]=0;
for(int i=1;i<=degree;i++)
b[i][i]=1;
int (*t2)[MAXN],(*t3)[MAXN];
t2=b,t3=c;
while(h)
{
if(h&1)
{multi(t1,t2,t3);
swap(t2,t3);
}
h>>=1;
multi(t1,t1,t3);
swap(t1,t3);
}
if(t2!=a)
{
memcpy(a,t2,sizeof a);
}
}
int inline getid(char C)
{
if(C=='A')return 0;
else if(C=='T')return 1;
else if(C=='C')return 2;
else return 3;
}
void insert(int r,char *s)
{
int len=strlen(s);
for(int i=0;i<len;i++)
{
int val=getid(s[i]);
if(trie[r].nxt[val]==0)
trie[r].nxt[val]=++tot;
r=trie[r].nxt[val];
}
trie[r].flag=1;//1表示结束节点
}
void build(int r)
{
trie[r].fail=r;
myq[tail++]=r;
int ch;
while(head<tail)
{
r=myq[head++];
for(int i=0;i<MAXC;i++)
{
ch=trie[r].nxt[i];
if(ch)myq[tail++]=ch;
if(r==root)
{
if(ch)
trie[ch].fail=root;
else trie[r].nxt[i]=root;
}
else
{
if(ch)
{trie[ch].fail=trie[trie[r].fail].nxt[i];
trie[ch].flag|=trie[trie[ch].fail].flag;
}
else
trie[r].nxt[i]=trie[trie[r].fail].nxt[i];
}
ch=trie[r].nxt[i];
if(trie[ch].flag!=1)
a[r][ch]++;
}
}
}
int main()
{
int n,m;
scanf("%d%d",&m,&n);
for(int i=0;i<m;i++)
{
scanf("%s",word);
insert(root,word);
}
build(root);
degree=tot;
power(a,n);
int ans=0;
for(int i=1;i<=degree;i++)
{ans+=a[1][i];
ans%=MOD;
}
printf("%d\n",ans);
}