BZOJ 3202 项链

时间:2022-03-18 14:02:01

题目连接:http://www.lydsy.com:808/JudgeOnline/problem.php?id=3202

题意:一个项链由n个珠子组成。每个珠子有三个面,每个面上有一个数字,要求每个珠子三个面的数字的Gcd值为1。三个数排序后相同的算作一种,即珠子(1,3,4)和珠子(3,1,4)是一样的。每个面的数字范围为[1,a]。项链中相邻珠子不能相同,旋转后相同的算作一种。求不同的项链个数。模M=1e9+7。

思路:首先可以算出不同珠子的种类,容斥莫比乌斯啥的,设有m种。接下来设长度为n的项链满足相邻不一样的方案数,设为f(n),那么f(n)=f(n-1)*(m-2)+f(n-2)*(m-1)。这样答案为:

BZOJ 3202 项链

const i64 M1=1000000007;
const i64 M2=1000000014000000049;
const int N=10000005;

int prime[N],cnt,tag[N];
int mou[N];

void init()
{
    int i,j;
    for(i=2;i<N;i++)
    {
        if(!tag[i]) prime[cnt++]=i,mou[i]=-1;
        for(j=0;j<cnt&&(i64)i*prime[j]<N;j++)
        {
            tag[i*prime[j]]=1;
            if(i%prime[j]) mou[i*prime[j]]=-mou[i];
            else
            {
                mou[i*prime[j]]=0;
                break;
            }
        }
    }
    mou[1]=1;
}

i64 n,a,mod;

i64 mul(i64 x,i64 y)
{
    x%=mod;
    i64 ans=0;
    if(y<0) x=-x,y=-y;
    while(y)
    {
        if(y&1) ans=(ans+x)%mod;
        x=(x+x)%mod;
        y>>=1;
    }
    if(ans<0) ans+=mod;
    return ans;
}

i64 C2(i64 x)
{
    if(x<2) return 0;
    return x*(x-1)/2%mod;
}

i64 C3(i64 x)
{
    if(x<3) return 0;
    i64 a=x,b=x-1,c=x-2;
    if(a%2==0) a>>=1;
    else b>>=1;

    if(a%3==0) a/=3;
    else if(b%3==0) b/=3;
    else c/=3;

    return mul(mul(a,b),c);
}

i64 cal()
{
    i64 ans=1;
    int i;
    for(i=1;i<=a;i++) if(mou[i])
    {
        ans+=mou[i]*C2(a/i)*2;
        ans%=mod;
        ans+=mou[i]*C3(a/i);
        ans%=mod;
    }
    return ans;
}

i64 myPow(i64 x,i64 y)
{
    i64 ans=1;
    while(y)
    {
        if(y&1) ans=mul(ans,x);
        x=mul(x,x);
        y>>=1;
    }
    return ans;
}

i64 m;

i64 eular(i64 x)
{
    i64 i;
    i64 ans=x;
    for(i=0;i<cnt&&(i64)prime[i]*prime[i]<=x;i++) if(x%prime[i]==0)
    {
        ans-=ans/prime[i];
        while(x%prime[i]==0) x/=prime[i];
    }
    if(x>1) ans-=ans/x;
    return ans;
}

i64 f(i64 n)
{
    if(n%2==0) return (myPow(m-1,n)+(m-1))%mod;
    return (myPow(m-1,n)-(m-1))%mod;
}

int main()
{
    init();
    int T=getInt();
    while(T--)
    {
        n=getInt();
        a=getInt();

        if(n%M1==0) mod=M2;
        else mod=M1;

        m=cal();

        i64 ans=0;
        int i;
        for(i=1;(i64)i*i<=n;i++) if(n%i==0)
        {
            ans+=mul(f(i),eular(n/i));
            ans%=mod;
            if(n/i!=i) ans+=mul(f(n/i),eular(i));
            ans%=mod;
        }
        if(mod==M1) ans=mul(ans,myPow(n%mod,mod-2));
        else
        {
            mod=M1;
            ans=ans/mod*myPow(n/mod,mod-2)%mod;
        }

        if(ans<0) ans+=mod;
        printf("%lld\n",ans);
    }
}