4737: 组合数问题 lucas定理+数位DP

时间:2022-04-16 18:58:26

Description
组合数C(n,m)表示的是从n个物品中选出m个物品的方案数。举个例子,从(1,2,3)三个物品中选择两个物品可以有(1,2),(1,3),(2,3)这三种选择方法。根据组合数的定义,我们可以给出计算组合数C(n,m)的一般公式:
C(n,m)=n!/m!*(n?m)!
其中n!=1×2×?×n。(额外的,当n=0时,n!=1)
小葱想知道如果给定n,m和k,对于所有的0≤i≤n,0≤j≤min(i,m)有多少对(i,j)满足C(i,j)是k的倍数。

题解:

相信大家都知道lucas定理的形式: C n m = C n / p m / p C n % p m % p % p ,实际上,我们还可以这样理解:把 n , m 看做一个 p 进制的数,那么结果等于 C a 0 b 0 × C a 1 b 1 × C a k b k a i n 分解成 p 进制后的第 i 位, b 也类似。然后若结果不是 p 的倍数,当且仅当所有 a i >= b i ,因为这样的话所有 a i b i 都是小于 p 的,不会出现 p 的因子;否则就会有几项为 0 ,然后根据这个数位DP即可,转移有点恶心,要细心。

代码:

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=110;
const LL mod=1000000007LL,inv=500000004LL;
LL read()
{
    LL x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    return x*f;
}
LL n,m,a[70],b[70],la,lb;
LL f[70][2][2];
//位数 是否小于n 是否小于m
LL sum(LL l,LL r){l%=mod;r%=mod;return (l+r)%mod*((r-l+1LL)%mod)%mod*inv%mod;}   
int main()
{
    int T=read();
    LL k=read();
    while(T--)
    {
        memset(a,0,sizeof(a));
        memset(b,0,sizeof(b));
        memset(f,0,sizeof(f));
        n=read(),m=read();la=lb=0;
        if(m>n)m=n;
        LL t,ans;
        if(n==m)ans=sum(1,n+1);
        else ans=(sum(1,m+1)+(n-m)%mod*((m+1)%mod)%mod)%mod;
        t=n;
        while(t)a[++la]=t%k,t/=k;
        for(int i=1;i<=(la>>1);i++)swap(a[i],a[la-i+1]);
        t=m;
        while(t)b[++lb]=t%k,t/=k;
        for(int i=1;i<=(lb>>1);i++)swap(b[i],b[lb-i+1]);
        int len=max(la,lb);
        if(la!=len)for(int i=la;i;i--)a[i+len-la]=a[i],a[i]=0;
        if(lb!=len)for(int i=lb;i;i--)b[i+len-lb]=b[i],b[i]=0;
        f[0][0][0]=1;
        for(int i=0;i<len;i++)
        for(int pn=0;pn<2;pn++)
        for(int pm=0;pm<2;pm++)
        if(f[i][pn][pm])
        {
            LL t=f[i][pn][pm];
            if(pn&&pm)f[i+1][1][1]=(f[i+1][1][1]+t*sum(1,k)%mod)%mod;
            if(pn&&!pm)
            {
                f[i+1][1][1]=(f[i+1][1][1]+t*(sum(1,b[i+1])+(k-b[i+1])*b[i+1]%mod)%mod)%mod;
                f[i+1][1][0]=(f[i+1][1][0]+t*(k-b[i+1])%mod)%mod;
            }
            if(!pn&&pm)
            {
                f[i+1][1][1]=(f[i+1][1][1]+t*sum(1,a[i+1])%mod)%mod;
                f[i+1][0][1]=(f[i+1][0][1]+t*(a[i+1]+1)%mod)%mod;
            }
            if(!pn&&!pm)
            {
                if(a[i+1]&&b[i+1])
                {
                    if(a[i+1]<=b[i+1])f[i+1][1][1]=(f[i+1][1][1]+t*sum(1,a[i+1])%mod)%mod;
                    else f[i+1][1][1]=(f[i+1][1][1]+t*(sum(1,b[i+1])+(a[i+1]-b[i+1])*b[i+1]%mod)%mod)%mod;
                }
                if(b[i+1])f[i+1][0][1]=(f[i+1][0][1]+t*min(a[i+1]+1,b[i+1])%mod)%mod;
                if(a[i+1]&&a[i+1]-1>=b[i+1])f[i+1][1][0]=(f[i+1][1][0]+t*(a[i+1]-b[i+1])%mod)%mod;
                if(a[i+1]>=b[i+1])f[i+1][0][0]=(f[i+1][0][0]+t)%mod;
            }
        }
        for(int i=0;i<2;i++)
        for(int j=0;j<2;j++)
        ans=((ans-f[len][i][j])%mod+mod)%mod;
        printf("%lld\n",ans);
    }
}