bzoj 3598 [Scoi2014]方伯伯的商场之旅 数位dp

时间:2022-12-16 12:15:14

当位置向后移动时,可以发现答案加上前面一段所有数字之和减去后面一段所有数字之和。
然后前面一段所有数字之和单调不减,后面一段所有数字之和单调不增。
为了避免重复找最前面的满足的点。设答案的位置为a1,a1的下一个位置为a2。
bzoj 3598 [Scoi2014]方伯伯的商场之旅 数位dp
应该选取的位置满足 s1<a1+a2+s2     s1+a1a2+s2
然后对左边右边分别数位dp,维护个数,移动到一边的代价和。
最后枚举分界点,把两边合并起来。
感觉我写得很恶心,一坨式子。。。

#include <bits/stdc++.h>
using namespace std;
#define ll long long
ll L,R,ret;
int K,top;
int st[71];
ll f[71][310][2],f1[71][310][2];
ll g[71][310][2],g1[71][310][2];
void upd1(int i,int j,int k,int r1,int r2)
{
    f[i+1][j+k][r2]+=f[i][j][r1];
    f1[i+1][j+k][r2]+=f[i][j][r1]*(j+k)+f1[i][j][r1];
}
void upd2(int i,int j,int k,int r1,int r2)
{
    g[i-1][j+k][r2]+=g[i][j][r1];
    g1[i-1][j+k][r2]+=g[i][j][r1]*(j+k)+g1[i][j][r1];
}
void upd(int i,int s2,int a2,int a1,int r1,int r2)
{
    int t1=i-2<0 ? 0:i-2,t2=i+1;
    int l=max(a2-a1+s2,0),r=a1+a2+s2-1;
    if(r<0)return;
    if(l)
    {
        ret+=f[t1][s2][r1]*(g1[t2][r][r2]-g1[t2][l-1][r2])+
        (f1[t1][s2][r1]+f[t1][s2][r1]*(s2+a2))*(g[t2][r][r2]-g[t2][l-1][r2]);
    }
    else ret+=f[t1][s2][r1]*g1[t2][r][r2]+
        (f1[t1][s2][r1]+f[t1][s2][r1]*(s2+a2))*g[t2][r][r2];
}
ll solve(ll x)
{
    if(!x)return 0;
    top=0;ret=0;
    while(x)st[++top]=x%K,x/=K;
    memset(f,0,sizeof(f));
    memset(f1,0,sizeof(f1));
    memset(g,0,sizeof(g));
    memset(g1,0,sizeof(g1));
    f[0][0][0]=1;
    for(int i=0;i<top;i++)
        for(int j=0;j<=(K-1)*i;j++)
            for(int k=0;k<K;k++)
            {
                upd1(i,j,k,0,k>st[i+1]);
                upd1(i,j,k,1,k>=st[i+1]);
            }
    g[top+1][0][1]=1;
    for(int i=top+1;i;i--)
        for(int j=0;j<=(K-1)*(top+1-i);j++)
            for(int k=0;k<K;k++)
            {
                upd2(i,j,k,0,0);
                if(k<=st[i-1])
                    upd2(i,j,k,1,k==st[i-1]);
            }
    for(int i=0;i<=top+1;i++)
        for(int j=1;j<=(K-1)*top;j++)
            for(int k=0;k<=1;k++)
            {
                g[i][j][k]+=g[i][j-1][k];
                g1[i][j][k]+=g1[i][j-1][k];
            }
    for(int i=1;i<=top;i++)
        for(int j=0;j<=(K-1)*(i-2)||j<=0;j++)
            for(int k=0;k<K;k++)
                for(int u=0;u<K;u++)
                {
                    if(i==1&&k)continue;
                    upd(i,j,k,u,0,0);
                    upd(i,j,k,u,1,0);   
                    if(k==st[i-1]&&u==st[i])
                        upd(i,j,k,u,0,1);
                    else if(u<st[i]||(u==st[i]&&k<st[i-1]))
                    {
                        upd(i,j,k,u,1,1);
                        upd(i,j,k,u,0,1);
                    }
                }
    return ret;
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%lld%lld%d",&L,&R,&K);
    printf("%lld\n",solve(R)-solve(L-1));
    return 0;
}