题目大意
你有\(s_1\)种\(1\times 2\)的地砖,\(s_2\)种\(2\times 1\)的地砖。
记铺满\(m\times n\)的地板的方案数为\(f(m,n)\)。
给你\(m,l,r,s_1,s_2\),求\(\sum_{i=l}^rf(m,i)\)
\(m\leq 6,1\leq l\leq r\leq {10}^{2501}\)
题解
显然是状压DP。
显然可以矩阵快速幂。
怎么矩阵快速幂?
假设矩阵是\(2^m\times 2^m\)的,我们把矩阵扩大一行一列,记录前面算出的铺满\(m\)行\(i\)列的方案数。
转移在原来转移的基础上增加\(f_{i-1,2^m-1}\longrightarrow f_{i,2^m}\)和\(f_{i-1,2^m}\longrightarrow f_{i,2^m}\)
这样\(f_{i,m}\)就是\(\sum_{j=1}^{i-1}f_{j,2^m-1}\)。
然后就可以矩阵快速幂了。
显然这个矩阵快速幂是可以用特征多项式+倍增取模优化的。
求特征多项式可以用\(O(n^3)\)的方法,也可以用\(O(n^4)\)的方法。
然后就没了。
时间复杂度:\(O(8^m+4^m\log r)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<iostream>
#include<vector>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
typedef vector<ll> poly;
const ll p=998244353;
ll fp(ll a,ll b)
{
ll s=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)
s=s*a%p;
return s;
}
struct matrix
{
ll a[130][130];
int n,m;
matrix()
{
memset(a,0,sizeof a);
n=m=0;
}
ll *operator [](int x)
{
return a[x];
}
};
matrix operator *(matrix a,matrix b)
{
matrix c;
c.n=a.n;
c.m=b.m;
for(int i=0;i<a.n;i++)
for(int j=0;j<a.m;j++)
{
ll s=0;
for(int k=0;k<b.m;k++)
s=(s+a[i][k]*b[k][j])%p;
c[i][j]=s;
}
return c;
}
namespace yww
{
poly f[200];
void add(poly &a,poly b,ll s1,ll s2)
{
int n=b.size();
while(int(a.size())<n+1)
a.push_back(0);
for(int i=0;i<n;i++)
{
a[i]=(a[i]+b[i]*s2)%p;
a[i+1]=(a[i+1]+b[i]*s1)%p;
}
while(a.back()==0)
a.pop_back();
}
poly getpoly(matrix a,int n)
{
for(int i=0;i<=n;i++)
{
int j;
for(j=i+1;j<=n;j++)
if(a[j][i])
break;
if(j>n)
continue;
if(j!=i+1)
{
for(int k=i;k<=n;k++)
swap(a[i+1][k],a[j][k]);
for(int k=0;k<=n;k++)
swap(a[k][i+1],a[k][j]);
}
for(int j=i+2;j<=n;j++)
if(a[j][i])
{
ll v=fp(a[i+1][i],p-2)*a[j][i]%p;
for(int k=i;k<=n;k++)
a[j][k]=(a[j][k]-a[i+1][k]*v)%p;
for(int k=0;k<=n;k++)
a[k][i+1]=(a[k][i+1]+a[k][j]*v)%p;
}
}
f[n+1].push_back(1);
for(int i=n;i>=0;i--)
{
add(f[i],f[i+1],1,-a[i][i]);
ll v=1;
for(int j=i+2;j<=n+1;j++)
{
v=v*a[j-1][j-2]%p;
add(f[i],f[j],0,-v*a[i][j-1]%p);
}
}
return f[0];
}
}
matrix a;
int m,s1,s2,all;
void dfs(int x,int a1,int a2,ll s)
{
if(x>m+1)
return;
if(x>m)
{
a[all^a1][a2]=(a[all^a1][a2]+s)%p;
return;
}
dfs(x+1,a1,a2,s);
dfs(x+1,a1|(1<<(x-1)),a2|(1<<(x-1)),s*s1%p);
dfs(x+2,a1,a2|(3<<(x-1)),s*s2%p);
}
poly aa;
int len;
//poly f[20];
matrix g[200];
poly operator *(poly a,poly b)
{
int n=a.size()-1;
int m=b.size()-1;
poly c(n+m+1);
for(int j=0;j<=m;j++)
if(b[j])
for(int i=0;i<=n;i++)
c[i+j]=(c[i+j]+a[i]*b[j])%p;
return c;
}
poly operator %(poly a,poly b)
{
int n=a.size()-1;
int m=b.size()-1;
for(int i=n;i>=m;i--)
if(a[i])
{
ll v=a[i];
for(int j=0;j<=m;j++)
a[i-m+j]=(a[i-m+j]-b[j]*v)%p;
}
while(!a.back())
a.pop_back();
return a;
}
poly a1;
void init()
{
all=(1<<m)-1;
dfs(1,0,0,1);
a[all][all+1]=1;
a[all+1][all+1]=1;
a.n=a.m=all+2;
aa=yww::getpoly(a,all+1);
len=aa.size()-1;
// for(int i=0;i<=len;i++)
// c[i]=aa[i];
// f[0].push_back(1);
// for(int i=1;i<=9;i++)
// {
// f[i].resize(i+1);
// f[i][i]=1;
// f[i]=f[i]%aa;
//// for(auto v:f[i])
//// printf("%lld ",(v+p)%p);
//// printf("\n");
// }
g[0][0][all]=1;
g[0].n=1;
g[0].m=a.n;
for(int i=1;i<len;i++)
{
g[i]=g[i-1]*a;
// printf("%lld\n",(g[i][0][all+1]+p)%p);
}
a1.resize(2);
a1[0]=0;
a1[1]=1;
a1=a1%aa;
}
char l[10010],r[10010];
int e[100010];
int bit[100010];
poly f;
void calc(int n)
{
if(!n)
return;
calc(n-1);
f=f*f%aa;
if(bit[n])
f=f*a1%aa;
}
int solve(char *str,int b)
{
int n=strlen(str+1);
memset(e,0,sizeof e);
for(int i=1;i<=n;i++)
e[i]=str[n-i+1]-'0';
e[1]+=b;
int i;
for(i=1;e[i]>=10;i++)
{
e[i+1]+=e[i]/10;
e[i]%=10;
}
n=max(n,i);
memset(bit,0,sizeof bit);
int k=1;
for(int i=n;i>=1;i--)
{
int s=0;
int j;
for(j=1;j<=k||s;j++)
{
s+=bit[j]*10;
bit[j]=s&1;
s>>=1;
}
k=max(k,j);
bit[1]+=e[i];
for(j=1;bit[j]>=2;j++)
{
bit[j+1]+=bit[j]>>1;
bit[j]&=1;
}
k=max(k,j);
}
f.clear();
f.push_back(1);
reverse(bit+1,bit+k+1);
calc(k);
ll ans=0;
for(int i=0;i<f.size()&&i<len;i++)
{
// printf("%lld\n",(f[i]+p)%p);
ans=(ans+f[i]*g[i][0][all+1])%p;
}
return ans;
}
int main()
{
open("b");
scanf("%s%s",l+1,r+1);
scanf("%d%d%d",&m,&s1,&s2);
init();
ll ans1=solve(l,0);
ll ans2=solve(r,1);
ll ans=ans2-ans1;
ans=(ans%p+p)%p;
printf("%lld\n",ans);
return 0;
}