loj #547. 「LibreOJ β Round #7」匹配字符串

时间:2023-03-09 07:17:43
loj #547. 「LibreOJ β Round #7」匹配字符串

#547. 「LibreOJ β Round #7」匹配字符串

题目描述

对于一个 01 串(即由字符 0 和 1 组成的字符串)sss,我们称 sss 合法,当且仅当串 sss 的任意一个长度为 mmm 的子串 s′s's​′​​,不为全 1 串。

请求出所有长度为 nnn 的 01 串中,有多少合法的串,答案对 655376553765537 取模。

输入格式

输入共一行,包含两个正整数 n,mn,mn,m。

输出格式

输出共一行,表示所求的和对 655376553765537 取模的结果。

样例

样例输入 1

5 2

样例输出 1

13

样例解释 1

以下是所有合法的串:

00000
00001
00010
00100
00101
01000
01001
01010
10000
10001
10010
10100
10101

样例输入 2

2018 7

样例输出 2

27940

#include<iostream>
#include<cstdio>
#include<cstring>
#define mod 65537
using namespace std;
int n,m;
struct node{
int n,m;
int a[][];
node(){memset(a,,sizeof(a));}
node operator * (const node &b)const{
node res;
res.n=n;res.m=b.m;
for(int i=;i<=n;i++)
for(int j=;j<=b.m;j++)
for(int k=;k<=m;k++)
res.a[i][j]+=1LL*a[i][k]*b.a[k][j]%mod;
return res;
}
};
bool check(int sta){
int cnt=;
for(int i=;i<=n;i++){
if(sta&(<<i-))cnt++;
else cnt=;
if(cnt>=m)return ;
}
return ;
}
node Pow(node x,int y){
node res;
res.n=;res.m=;
res.a[][]=;res.a[][]=;
while(y){
if(y&)res=res*x;
x=x*x;
y>>=;
}
return res;
}
void work1(){
node a;
a.n=;a.m=;
a.a[][]=;a.a[][]=;
node b;
b.n=b.m=;
b.a[][]=;b.a[][]=;b.a[][]=;
b=Pow(b,n-);
a=a*b;
int ans=(a.a[][]+a.a[][])%mod;
printf("%d",ans);
}
int main(){
scanf("%d%d",&n,&m);
if(m==){puts("");return ;}
if(m==){work1();return ;}
int ans=;
for(int sta=;sta<(<<n);sta++)
if(check(sta)){
ans++;
if(ans>=mod)ans-=mod;
}
printf("%d",ans);
return ;
}

13分 矩阵快速幂优化dp+枚举

#include<iostream>
#include<cstdio>
#include<cstring>
#define mod 65537
using namespace std;
long long n,m;
int tmp[],b[],c[],ans,inv[mod],fac[mod],bin[];
int Pow(int x,int y){
int res=;
while(y){
if(y&)res=1LL*res*x%mod;
x=1LL*x*x%mod;
y>>=;
}
return res;
}
int C(int n,int m){
if(m>n)return ;
return 1LL*fac[n]*inv[m]%mod*inv[n-m]%mod;
}
int Lucas(long long n,long long m){
if(!m)return ;
return 1LL*Lucas(n/mod,m/mod)*C(n%mod,m%mod)%mod;
}
void mul(int *b,int *c){
for(int i=;i<m;i++)
for(int j=;j<m;j++)
tmp[i+j]=(tmp[i+j]+1LL*b[i]*c[j]%mod)%mod;
for(int i=*m-;i>=m;i--)
for(int j=;j<=m;j++)
tmp[i-j]=(tmp[i-j]+tmp[i])%mod;
for(int i=;i<m;i++)b[i]=tmp[i],tmp[i]=tmp[i+m]=;
}
void solve1(){
bin[]=c[]=;
if(m==)b[]=;
else b[]=;
for(int i=;i<m;i++)
bin[i]=1LL**bin[i-]%mod;
while(n){
if(n&)mul(c,b);
mul(b,b);
n>>=;
}
for(int i=;i<m;i++)
ans=(ans+1LL*bin[i]*c[i]%mod)%mod;
cout<<ans;
}
int s(long long n){
long long base=Pow(Pow(,m+),mod-);
long long cc=Pow(,n);int res=;
for(int k=;k*(m+)<=n;k++){
long long tt=1LL*Lucas(n-k*m,k)*cc%mod;
tt=(k&)?mod-tt:tt;
res=(res+tt)>=mod?res-mod+tt:res+tt;
cc=1LL*cc*base%mod;
}
return res;
}
void solve2(){
fac[]=fac[]=;
for(int i=;i<mod;i++)fac[i]=1LL*fac[i-]*i%mod;
inv[mod-]=mod-;
for(int i=mod-;i>=;i--)inv[i-]=1LL*inv[i]*i%mod;
ans=s(n+)-s(n);
printf("%d\n",(ans<)?ans+mod:ans);
}
int main(){
cin>>n>>m;
if(m<=)solve1();
else solve2();
return ;
}

100分