【HDOJ】5632 Rikka with Array

时间:2022-10-11 07:14:34

1. 题目描述
$A[i]$表示$i$的二进制表示中各位数字之和。求满足$1 \le i < j \le n$并且$A[i]>A[j]$的$(i,j)$的总对数。

2. 基本思路
$n \le 10^{300}$。$n$这么大,显然只能用数位DP来做。我们可以预先将$n$转换成二进制表示,然后再进行DP。
$dp[i][j][k]$表示已确定二进制前$i$位、两者$A$的差(即$A[x]-A[y]$)为$j$、状态为$k$的$(x,y)$的总数。
不妨令$|n| = l$,因此$j \in [-l, l]$,因此需要$+l$,将$j$映射到$[0,l*2]$上。
再考虑$k$有多少种情况?不妨令$(x,y), x<y$表示一对可行解。
(0) $Pref(x) < Pref(y), Pref(y) < Pref(n)$;
(1) $Pref(x) < Pref(y), Pref(y) == Pref(n)$;
(2) $Pref(x) == Pref(y), Pref(y) < Pref(n)$;
(3) $Pref(x) == Pref(y), Pref(y) == Pref(n)$;
上面4种情况分别对应$k \in [0, 3]$,剩下的就是写状态转移了,还是挺简单的。总对数就是
\[\sum_{j = l+1}^{2l}\left(dp[l][j][0]+dp[l][j][1]\right)\]
可以使用滚动数组优化,其实也可以不使用。

3. 代码

 /* 5632 */
#include <iostream>
#include <sstream>
#include <string>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
#include <deque>
#include <bitset>
#include <algorithm>
#include <cstdio>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <cctype>
#include <cassert>
#include <functional>
#include <iterator>
#include <iomanip>
using namespace std;
//#pragma comment(linker,"/STACK:102400000,1024000")

#define sti set<int>
#define stpii set<pair<int, int> >
#define mpii map<int,int>
#define vi vector<int>
#define pii pair<int,int>
#define vpii vector<pair<int,int> >
#define rep(i, a, n) for (int i=a;i<n;++i)
#define per(i, a, n) for (int i=n-1;i>=a;--i)
#define clr clear
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define all(x) (x).begin(),(x).end()
#define SZ(x) ((int)(x).size())
#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1

// NOTE(review): the numeric literals of this listing were lost; the values
// below were reconstructed from the algorithm. The modulus is the one the
// problem statement is believed to ask for - confirm against the statement.
const int mod = 998244353;
// Decimal input buffer: n <= 10^300 has at most 301 digits; main() reads at
// most maxl-1 = 304 characters.
const int maxl = 305;
// Binary digits of n: 304 decimal digits need at most 1010 bits.
const int maxn = 1024;

char ss[maxl];          // decimal digits of n (ASCII on input, 0..9 after solve() starts)
int a[maxn];            // binary digits of n, most significant first
int dp[2][maxn<<1][4];  // rolling DP table, see solve()

/*
 * Converts the decimal number stored as digit values (0..9, not ASCII) in
 * digits[0..len) into binary by repeated halving. The digit buffer is
 * overwritten in the process.
 *
 * Stores the bits most-significant first in bits[] and returns how many bits
 * were produced; returns 0 when the number is zero.
 */
static int decimalToBinary(char* digits, int len, int* bits) {
    int first = 0;  // index of the first non-zero decimal digit
    while (first < len && digits[first] == 0)
        ++first;

    int count = 0;
    while (first < len) {
        // The parity of the last decimal digit is the next binary digit.
        bits[count++] = digits[len - 1] & 1;

        // Long division by two; an odd digit carries 10 into the next one.
        int carry = 0;
        for (int i = first; i < len; ++i) {
            const int cur = carry + digits[i];
            carry = (digits[i] & 1) ? 10 : 0;
            digits[i] = (char)(cur >> 1);
        }
        while (first < len && digits[first] == 0)
            ++first;
    }

    std::reverse(bits, bits + count);
    return count;
}

/*
 * Reads the decimal number n from the global buffer ss (NUL-terminated ASCII
 * digits, destroyed by the call) and prints, modulo mod, the number of pairs
 * (x, y) with x < y <= n whose binary digit sums satisfy A[x] > A[y].
 *
 * Digit DP over the l binary digits of n, most significant first.
 * dp[cur][j][k] = number of prefix pairs with A[x] - A[y] == j - l and
 * state k, where
 *   bit 0 of k: the prefix of y still equals the prefix of n (tight),
 *   bit 1 of k: the prefix of x still equals the prefix of y.
 * x = 0 is enumerated as well, but A[0] = 0 can never exceed A[y], so it
 * contributes nothing to the answer.
 */
void solve() {
    const int len = (int)std::strlen(ss);
    for (int i = 0; i < len; ++i)
        ss[i] -= '0';

    const int l = decimalToBinary(ss, len, a);
    if (l == 0) {
        // n == 0: there is no valid pair.
        std::puts("0");
        return;
    }

    const int l2 = l + l;
    int p = 0, q = 1;  // current / next layer of the rolling array

    std::memset(dp, 0, sizeof(dp));
    // Empty prefix: difference 0 (stored with offset l), x == y, y == n.
    dp[p][l][3] = 1;

    for (int i = 0; i < l; ++i) {
        for (int j = 0; j <= l2; ++j) {
            for (int k = 0; k < 4; ++k) {
                const int ways = dp[p][j][k];
                if (!ways)
                    continue;

                const bool tight = (k & 1) != 0;
                const bool equal = (k & 2) != 0;
                // y may not exceed n while it is still tight.
                const int maxY = tight ? a[i] : 1;

                for (int yb = 0; yb <= maxY; ++yb) {
                    // x may not exceed y while the prefixes are still equal.
                    const int maxX = equal ? yb : 1;
                    for (int xb = 0; xb <= maxX; ++xb) {
                        const int nj = j + xb - yb;
                        if (nj < 0 || nj > l2)
                            continue;
                        const int nk = ((tight && yb == a[i]) ? 1 : 0)
                                     | ((equal && xb == yb) ? 2 : 0);
                        dp[q][nj][nk] = (dp[q][nj][nk] + ways) % mod;
                    }
                }
            }
        }
        p ^= 1;
        q ^= 1;
        // The old current layer becomes the next target; wipe it.
        std::memset(dp[q], 0, sizeof(dp[q]));
    }

    // Pairs with x < y (states 0 and 1) and a strictly positive difference.
    int ans = 0;
    for (int j = l + 1; j <= l2; ++j)
        for (int k = 0; k < 2; ++k)
            ans = (ans + dp[p][j][k]) % mod;

    std::printf("%d\n", ans);
}

/*
 * Reads the number of test cases, then one decimal number per case, and
 * prints one answer per line. Outside the online judge, input and output are
 * redirected to data.in / data.out and a timing line is appended.
 */
int main() {
#ifndef ONLINE_JUDGE
    std::freopen("data.in", "r", stdin);
    std::freopen("data.out", "w", stdout);
#endif

    int t;
    if (std::scanf("%d", &t) != 1)
        return 0;
    while (t--) {
        // Field width is maxl-1 so an over-long token cannot overflow ss.
        if (std::scanf("%304s", ss) != 1)
            break;
        solve();
    }

#ifndef ONLINE_JUDGE
    std::printf("time = %d.\n", (int)std::clock());
#endif

    return 0;
}

4. 数据生成器

 # Random test-data generator and output checker for the solution above.
import sys
import string
from random import randint, shuffle


def GenData(fileName):
    """Write 10 random test cases to fileName.

    The first line holds the case count; every following line is a random
    decimal integer of 200 to 300 digits whose leading digit is non-zero.
    """
    with open(fileName, "w") as fout:
        t = 10
        fout.write("%d\n" % (t))
        for tt in range(t):
            length = randint(200, 300)
            L = [randint(0, 9) for _ in range(length)]
            # Redraw the first digit so the number has no leading zero.
            L[0] = randint(1, 9)
            fout.write("".join(map(str, L)) + "\n")


def MovData(srcFileName, desFileName):
    """Copy the text file srcFileName to desFileName."""
    with open(srcFileName, "r") as fin:
        lines = fin.readlines()
    with open(desFileName, "w") as fout:
        fout.write("".join(lines))


def CompData(srcFileName=r"F:\Qt_prj\hdoj\data.out",
             desFileName=r"F:\workspace\cpp_hdoj\data.out"):
    """Compare two answer files line by line and report differing lines.

    Prints "<index>: wrong" for every line whose integer value differs and
    returns the number of such lines. The last line of the shorter file is
    skipped: local builds of the C++ program append a "time = ..." line
    there, which is not an answer.
    """
    print("comp")
    with open(srcFileName, "r") as fin:
        srcLines = fin.readlines()
    with open(desFileName, "r") as fin:
        desLines = fin.readlines()
    n = min(len(srcLines), len(desLines)) - 1
    wrong = 0
    for i in range(n):
        ans2 = int(desLines[i])
        ans1 = int(srcLines[i])
        # Any difference is a mismatch, not only ans1 > ans2.
        if ans1 != ans2:
            print("%d: wrong" % i)
            wrong += 1
    return wrong


if __name__ == "__main__":
    srcFileName = r"F:\Qt_prj\hdoj\data.in"
    desFileName = r"F:\workspace\cpp_hdoj\data.in"
    GenData(srcFileName)
    MovData(srcFileName, desFileName)