BZOJ 1444:[JSOI2009]有趣的游戏
首先我们建出Trie图,然后高斯消元。
我们设\(f_i\)表示经过第\(i\)个点的期望次数:
\[f_x=\sum i\cdot p_x(i)
\]
\]
\(p_x(i)\)表示经过第\(x\)个点\(i\)次的概率。我们设表示一个单词的节点为关键节点,则所有关键节点只会经过一次,也就是说\(f_{关键}=p_{关键}(1)\),也就是我们要求的答案。
\[\displaystyle f_x=\sum_{y与x相连}rate_{y\Rightarrow x}f_y
\]
\]
特别地\(\displaystyle f_1=\sum_{y与1相连}rate_{y\Rightarrow 1}f_y+1\),因为初始点在\(1\)。
\(rate_{y\Rightarrow x}\)就是能从\(y\)走到\(x\)的字母的出现概率。
根据这些等式列方程,再高斯消元就行了。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 12
#define eps 1e-7
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
int n,l,m;
double rate[26];
double w[N*N][N*N];
char str[N];
namespace AC_automation {
int cnt=1;
int id[N];
struct trie {
int ch[26];
int w,fail;
}tr[N*N];
void Insert(char *s,int No) {
int len=strlen(s+1),now=1;
for(int i=1;i<=len;i++) {
int j=s[i]-'A';
if(!tr[now].ch[j]) tr[now].ch[j]=++cnt;
now=tr[now].ch[j];
}
id[No]=now;
tr[now].w=1;
}
queue<int>q;
void build_fail() {
q.push(1);
while(!q.empty()) {
int v=q.front();
q.pop();
for(int i=0;i<26;i++) {
if(!tr[v].ch[i]) continue ;
int sn=tr[v].ch[i],f=tr[v].fail;
while(f&&!tr[f].ch[i]) f=tr[f].fail;
if(!f) tr[sn].fail=1;
else tr[sn].fail=tr[f].ch[i];
q.push(sn);
}
}
}
int find_sn(int now,int j) {
while(now&&!tr[now].ch[j]) now=tr[now].fail;
return now?tr[now].ch[j]:1;
}
void build_matrix() {
for(int i=1;i<=cnt;i++) {
w[i][i]=-1;
if(tr[i].w) continue ;
else {
for(int j=0;j<m;j++) {
int sn=find_sn(i,j);
w[sn][i]+=rate[j];
}
}
}
w[1][cnt+1]=-1;
}
}
int sum;
double ans[N*N];
void Gauss(int n) {
for(int i=1;i<=n;i++) {
for(int j=i+1;j<=n;j++) {
if(fabs(w[i][i])<fabs(w[j][i])) swap(w[i],w[j]);
if(fabs(w[i][i])<eps) continue ;
for(int j=i+1;j<=n;j++) {
double tem=w[j][i]/w[i][i];
for(int k=i;k<=n+1;k++) w[j][k]-=tem*w[i][k];
}
}
}
for(int i=n;i>=1;i--) {
if(fabs(w[i][i])<eps) {ans[i]=0;continue ;}
for(int j=i+1;j<=n;j++) w[i][n+1]-=w[i][j]*ans[j];
ans[i]=w[i][n+1]/w[i][i];
}
}
int main() {
n=Get(),l=Get(),m=Get();
double a,b;
for(int i=0;i<m;i++) {
a=Get(),b=Get();
rate[i]=a/b;
}
for(int i=1;i<=n;i++) {
scanf("%s",str+1);
AC_automation::Insert(str,i);
}
AC_automation::build_fail();
AC_automation::build_matrix();
sum=AC_automation::cnt;
Gauss(sum);
for(int i=1;i<=n;i++) {
double a=ans[AC_automation::id[i]];
if(fabs(a)>0.005) cout<<fixed<<setprecision(2)<<a<<"\n";
else cout<<"0.00"<<"\n";
}
return 0;
}