#loj3089 [BJOI2019]奥术神杖

时间:2021-04-29 21:06:52

卡精度好题

最关键的一步是几何平均数的\(ln\)等于所有数字取\(ln\)后的算术平均值

那么现在就变成了一个很裸的01分数规划问题,一个通用的思路就是二分答案

现在来考虑二分答案的底层怎么写

把所有串拉出来造ac自动机,那么ac自动机上一个点的权值就是

fail树上这个点到祖先的树链上的字符串的权值之和

那么接下来设\(f(i,j)\)表示决策到了第\(i\)个字符,走到自动机节点\(j\)的最大收益大力dp即可

由于我们不希望均值是0,因此额外记录下有没有匹配上模式串即可

二分完了之后把mid调成l重新跑一遍dp,不然你会跑出无解的情况导致输出错误的答案

有个函数叫log2,精度比log高,这样你就可以少二分几次了

// luogu-judger-enable-o2
// luogu-judger-enable-o2
// luogu-judger-enable-o2
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<queue>
using namespace std;const int N=4000;typedef long long ll;
typedef long double db;db mid;//const db eps=1e-8;
const db low_inf=-1e9;
int v[N<<1];int x[N<<1];int ct;int al[N];
db we[N];db sum[N];int num[N];int snum[N];
inline void add(int u,int V){
//printf("add %d %d\n",u,V);
v[++ct]=V;x[ct]=al[u];al[u]=ct;}
inline void pdfs(int u)
{
snum[u]+=num[u];
for(int i=al[u];i;i=x[i])
snum[v[i]]=snum[u],pdfs(v[i]);
}
inline void dfs(int u)
{
sum[u]+=we[u]-mid*num[u];
for(int i=al[u];i;i=x[i])
sum[v[i]]=sum[u],dfs(v[i]);
}
struct trie
{
int mp[N][12];int cnt;int fil[N];
inline int ins(int p,int c)
{return mp[p][c]=(mp[p][c])?mp[p][c]:++cnt;}
inline void build()
{
queue <int> q;
for(int i=1;i<=10;i++)
if(mp[1][i])fil[mp[1][i]]=1,q.push(mp[1][i]);
else mp[1][i]=1;
while(!q.empty())
{
int nw=q.front();q.pop();
for(int i=1;i<=10;i++)
if(mp[nw][i])fil[mp[nw][i]]=mp[fil[nw]][i],q.push(mp[nw][i]);
else mp[nw][i]=mp[fil[nw]][i];
}
for(int i=1;i<=cnt;i++)
if(fil[i])add(fil[i],i);
}
}tr;
struct data{int c;int lst;int pval;}fr[N][N];db dp[N][N];
char mde[N];int n;int m;char smde[N];int op[N];int hd;
inline void trans(int i,int j,int k)
{
int tw=tr.mp[j][k];
db tval=dp[i][j]+sum[tw];
if(tval>=dp[i+1][tw])
{
dp[i+1][tw]=tval;
fr[i+1][tw]=(data){k,j,fr[i][j].pval+snum[tw]};
}
}
inline void pritans()
{
db curmx=-0x3f3f3f3f;int st=-1;
for(int i=1;i<=tr.cnt;i++)
if(curmx<dp[n][i])
curmx=dp[n][i],st=i;
hd=0;
for(int i=n;i>=1;i--)
op[++hd]=fr[i][st].c,st=fr[i][st].lst;
for(int i=n;i>=1;i--)
printf("%d",op[i]-1);
printf("\n");
}
inline bool jud()
{
//for(int i=1;i<=tr.cnt;i++)
// printf("%.3Lf ",sum[i]);printf("\n");
sum[1]=0;
dfs(1);
for(int i=0;i<=n;i++)
for(int j=1;j<=tr.cnt;j++)
dp[i][j]=-0x3f3f3f3f;
dp[0][1]=0;
for(int i=0;i<n;i++)
for(int j=1;j<=tr.cnt;j++)
{
if(dp[i][j]<low_inf)continue;
if(mde[i+1]=='.')
for(int k=1;k<=10;k++)
trans(i,j,k);
else
trans(i,j,mde[i+1]-'0'+1);
}
db mx=-0x3f3f3f3f;
for(int i=1;i<=tr.cnt;i++)
if(fr[n][i].pval)mx=max(mx,dp[n][i]);
//printf("mx=%.10Lf\n",mx);
// pritans();
return mx>=0;
} int main()
{
//printf("%.10lf\n",exp(log(10)));
scanf("%d%d",&n,&m);
scanf("%s",mde+1);
tr.cnt=1;
for(int i=1,tmp;i<=m;i++)
{
scanf("%s",smde+1);
int p=1;
for(int j=1;smde[j]!='\0';j++)
p=tr.ins(p,smde[j]-'0'+1);
scanf("%d",&tmp);
we[p]+=log2(tmp);num[p]++;
// printf("%.10Lf\n",we[p]);
}
tr.build();
pdfs(1);
db l=0;db r=log2(1e9);
for(int i=1;i<=18;i++)
{
// printf("%.10Lf %.10Lf\n",l,r);
mid=(l+r)/2;
if(jud())l=mid;else r=mid;
}
mid=l;
jud();
pritans();
return 0;
}