BZOJ2142 礼物-扩展lucas

时间:2023-02-04 11:23:29

传送门
题意:

给定n及m个数a[1…m],再给定一个模数P,求

ni=1Caini1j=1aj%P

( 1n109,1m5 ,设 P=pc11pc22pc33pctt pi 为质数, 1pcii105 )

Solution:

扩展lucas定理模板题

这个算法理解起来比较方便,但是写的时候就不是那么好写了…

我们把p拆成 pc11pc22pc33pctt 的形式,这样我们只需要求出每个 C(n,m)%pcii ,再用CRT合并即可

而C(n,m)可以转化成阶乘及阶乘的逆元的乘积,由于扩欧求逆元要求互质,所以说我们计算阶乘时需要提取出和 pcii 互质的数

举一个其他博客用烂的栗子:

pi=3,ci=2,n=19

n!=1234567819

=[124578161719](369121518)

=[124578161719]36(123456)

发现后面的是 (n/pi)! ,于是递归即可。前半部分是以 pcii 为周期的 [124578]=[101113141617](mod9)

最后会有孤立出来的19,可以证明孤立出的长度不超过 pcii ,暴力算即可

至于剩下的 36 之类的数,我们只要计算出n!,m!,(n-m)!里含有多少个 pi (不妨设a,b,c),那么a-b-c就是C(n,m)中p的个数,直接算一下就行。

代码:

#include<cstdio>
#include<iostream>
using namespace std;
long long mod;
long long ans,sum;
long long a[10];
int n,m;
int nump,p;
long long tot;
int fast_pow(int a,int x,int mo)
{
int ans=1;
for (;x;x>>=1,a=1ll*a*a%mo)
if (x&1) ans=1ll*ans*a%mo;
return ans;
}
int mi(int n,int pi,int pk)
{
if (n==0) return 1;
int ans=1;
for (int i=2;i<=pk;i++)
if (i%pi) ans=1ll*ans*i%pk;
ans=fast_pow(ans,n/pk,pk);
for (int i=2;i<=n%pk;i++)
if (i%pi) ans=1ll*ans*i%pk;
return 1ll*ans*mi(n/pi,pi,pk)%pk;
}
void exgcd(int a,int b,long long &x,long long &y)
{
if (!b) x=1,y=0;
else exgcd(b,a%b,y,x),y-=a/b*x;
}
int inv(int a,int mo)
{
if (a==0) return 0;
long long x=0,y=0;
exgcd(a,mo,x,y);
x=(x%mo+mo)%mo;
if (x==0) x+=mo;
return x;
}
long long C(int n,int m,int pi,int pk)
{
if (m>n) return 0;
int a=mi(n,pi,pk);
int b=mi(m,pi,pk);
int c=mi(n-m,pi,pk);
int num=0;
for (int i=n;i;i/=pi) num+=i/pi;
for (int i=m;i;i/=pi) num-=i/pi;
for (int i=n-m;i;i/=pi) num-=i/pi;
int ans=1ll*a*inv(b,pk)%pk*inv(c,pk)%pk*fast_pow(pi,num,pk)%pk;
return 1ll*ans*(mod/pk)%mod*inv(mod/pk,pk)%mod;
}
int main()
{
scanf("%lld",&mod);
scanf("%d%d",&n,&m);
for (int i=1;i<=m;i++)
scanf("%d",&a[i]),tot+=a[i];
if (n<tot) {printf("Impossible\n");return 0;}
ans=1;
for (int i=1;i<=m;i++)
{
n-=a[i-1];
long long P=mod;sum=0;
for (int j=2;j*j<=P;j++)
{
if (P%j==0)
{
p=1;
while (P%j==0) P/=j,p*=j;
sum=(sum+C(n,a[i],j,p))%mod;
}
}
if (P>1) sum=(sum+C(n,a[i],P,P))%mod;
ans=(ans*sum)%mod;
}
printf("%lld\n",ans);
}