POJ 2151 Check the difficulty of problems:概率dp【至少】

时间:2023-03-09 00:07:47
POJ 2151 Check the difficulty of problems:概率dp【至少】

题目链接:http://poj.org/problem?id=2151

题意:

  一次ACM比赛,有t支队伍,比赛共m道题。

  第i支队伍做出第j道题的概率为p[i][j].

  问你所有队伍都至少做出一道,并且有队伍做出至少n道的概率。

题解:

  关于【至少】问题的表示。

  

  对于每一支队伍:

    mst[i][j] = P(第i支队伍做出至多j道题)

    则 P(第i支队伍做出至少j道题) = 1 - mst[i][j-1]

  

  对于所有队伍:

    P(所有队伍至少答出一题) = ∏ (1 - mst[i][0])

    P(所有队伍答题数在1到n-1) = ∏ (mst[i][n-1] - mst[i][0])

    所以答案:

    P(所有队伍至少答出一题,且有队伍做出至少n道) = P(所有队伍至少答出一题) - P(所有队伍答题数在1到n-1)

  所以求mst数组好啦~~~

  dp[i][j][k] = probability

  i:第i支队伍

  j:考虑到前j道题(包含j)

  k:恰好做出k道

  所以 mst[i][j] = sigma(dp[i][m][0 to j])

  怎么求dp数组呢:

    转移:dp[i][j][k] = dp[i][j-1][k-1]*p[i][j] + dp[i][j-1][k]*(1-p[i][j])

    边界:dp[i][0][0] = 1, others = 0

  所以这道题:先求dp,再求mst,最后统计ans。

AC Code:

 // state expression:
// dp[i][j][k] = probability
// i: ith team
// j: jth question and before
// k: solved k questions
// mst[i][j]
// i: ith team
// j: all the teams solved at most j questions
//
// find the answer:
// P(all 1 to m) - P(all 1 to n-1)
//
// transferring:
// dp[i][j][k] = dp[i][j-1][k-1]*p[i][j] + dp[i][j-1][k]*(1-p[i][j])
//
// boundary:
// dp[i][0][0] = 1
// others = 0
//
// calculate:
// mst[i][j] = sigma dp[i][m][0 to j]
// P1 = pi (1 - mst[i][0])
// P2 = pi (mst[i][n-1] - mst[i][0])
//
// step:
// 1) cal dp
// 2) cal mst
// 3) cal ans
#include <iostream>
#include <stdio.h>
#include <string.h>
#define MAX_T 1005
#define MAX_N 35
#define MAX_M 35 using namespace std; int n,m,t;
double p1,p2;
double p[MAX_T][MAX_M];
double dp[MAX_T][MAX_M][MAX_M];
double mst[MAX_T][MAX_M]; void read()
{
for(int i=;i<=t;i++)
{
for(int j=;j<=m;j++)
{
cin>>p[i][j];
}
}
} void cal_dp()
{
memset(dp,,sizeof(dp));
for(int i=;i<=t;i++)
{
dp[i][][]=;
for(int j=;j<=m;j++)
{
for(int k=;k<=m;k++)
{
if(k->=) dp[i][j][k]+=dp[i][j-][k-]*p[i][j];
dp[i][j][k]+=dp[i][j-][k]*(-p[i][j]);
}
}
}
} void cal_mst()
{
// mst[i][j] = sigma dp[i][m][0 to j]
memset(mst,,sizeof(mst));
for(int i=;i<=t;i++)
{
for(int j=;j<=m;j++)
{
for(int k=;k<=j;k++)
{
mst[i][j]+=dp[i][m][k];
}
}
}
} void cal_ans()
{
// P1 = pi (1 - mst[i][0])
// P2 = pi (mst[i][n-1] - mst[i][0])
p1=1.0;
p2=1.0;
for(int i=;i<=t;i++)
{
p1*=(-mst[i][]);
p2*=(mst[i][n-]-mst[i][]);
}
} void solve()
{
cal_dp();
cal_mst();
cal_ans();
} void print()
{
printf("%.3f\n",p1-p2);
} int main()
{
while(cin>>m>>t>>n)
{
if(m== && t== && n==) break;
read();
solve();
print();
}
}