
题目链接:http://poj.org/problem?id=2151
题意:
一次ACM比赛,有t支队伍,比赛共m道题。
第i支队伍做出第j道题的概率为p[i][j].
问你所有队伍都至少做出一道,并且有队伍做出至少n道的概率。
题解:
关于【至少】问题的表示。
对于每一支队伍:
mst[i][j] = P(第i支队伍做出至多j道题)
则 P(第i支队伍做出至少j道题) = 1 - mst[i][j-1]
对于所有队伍:
P(所有队伍至少答出一题) = ∏ (1 - mst[i][0])
P(所有队伍答题数在1到n-1) = ∏ (mst[i][n-1] - mst[i][0])
所以答案:
P(所有队伍至少答出一题,且有队伍做出至少n道) = P(所有队伍至少答出一题) - P(所有队伍答题数在1到n-1)
所以求mst数组好啦~~~
dp[i][j][k] = probability
i:第i支队伍
j:考虑到前j道题(包含j)
k:恰好做出k道
所以 mst[i][j] = sigma(dp[i][m][0 to j])
怎么求dp数组呢:
转移:dp[i][j][k] = dp[i][j-1][k-1]*p[i][j] + dp[i][j-1][k]*(1-p[i][j])
边界:dp[i][0][0] = 1, others = 0
所以这道题:先求dp,再求mst,最后统计ans。
AC Code:
// state expression:
// dp[i][j][k] = probability
// i: ith team
// j: jth question and before
// k: solved k questions
// mst[i][j]
// i: ith team
// j: all the teams solved at most j questions
//
// find the answer:
// P(all 1 to m) - P(all 1 to n-1)
//
// transferring:
// dp[i][j][k] = dp[i][j-1][k-1]*p[i][j] + dp[i][j-1][k]*(1-p[i][j])
//
// boundary:
// dp[i][0][0] = 1
// others = 0
//
// calculate:
// mst[i][j] = sigma dp[i][m][0 to j]
// P1 = pi (1 - mst[i][0])
// P2 = pi (mst[i][n-1] - mst[i][0])
//
// step:
// 1) cal dp
// 2) cal mst
// 3) cal ans
#include <iostream>
#include <stdio.h>
#include <string.h>
#define MAX_T 1005
#define MAX_N 35
#define MAX_M 35 using namespace std; int n,m,t;
double p1,p2;
double p[MAX_T][MAX_M];
double dp[MAX_T][MAX_M][MAX_M];
double mst[MAX_T][MAX_M]; void read()
{
for(int i=;i<=t;i++)
{
for(int j=;j<=m;j++)
{
cin>>p[i][j];
}
}
} void cal_dp()
{
memset(dp,,sizeof(dp));
for(int i=;i<=t;i++)
{
dp[i][][]=;
for(int j=;j<=m;j++)
{
for(int k=;k<=m;k++)
{
if(k->=) dp[i][j][k]+=dp[i][j-][k-]*p[i][j];
dp[i][j][k]+=dp[i][j-][k]*(-p[i][j]);
}
}
}
} void cal_mst()
{
// mst[i][j] = sigma dp[i][m][0 to j]
memset(mst,,sizeof(mst));
for(int i=;i<=t;i++)
{
for(int j=;j<=m;j++)
{
for(int k=;k<=j;k++)
{
mst[i][j]+=dp[i][m][k];
}
}
}
} void cal_ans()
{
// P1 = pi (1 - mst[i][0])
// P2 = pi (mst[i][n-1] - mst[i][0])
p1=1.0;
p2=1.0;
for(int i=;i<=t;i++)
{
p1*=(-mst[i][]);
p2*=(mst[i][n-]-mst[i][]);
}
} void solve()
{
cal_dp();
cal_mst();
cal_ans();
} void print()
{
printf("%.3f\n",p1-p2);
} int main()
{
while(cin>>m>>t>>n)
{
if(m== && t== && n==) break;
read();
solve();
print();
}
}