UVA 11990 `Dynamic'' Inversion CDQ分治, 归并排序, 树状数组, 尺取法, 三偏序统计 难度: 2

时间:2023-12-19 19:43:44

题目

https://uva.onlinejudge.org/index.php?option=com_onlinejudge&Itemid=8&page=show_problem&problem=3141

题意

一个1到n的排列,每次随机删除一个,问删除前的逆序数

思路

综合考虑,对每个数点,令value为值,pos为位置,time为出现时间(总时间-消失时间),明显是统计value1 > value2, pos1 < pos2, time1 < time2的个数

首先对其中一个轴排序,比如value,这样在归并过程中,左子树的value总是小于右子树的,可以分治。

当左右子树包含哪些数点已经确定后,可以用自下而上的归并排序使得子树上的数点按照第二维相对有序,方便用尺取法统计子树之间的逆序数。

第三维通过树状数组进行压缩,加快统计速度。

注意仅仅统计左子树对右子树的影响,就会错过右子树中的数点出现的比较晚的情况。因此需要统计右子树对左子树的影响,此时注意别把同一时间出现的重复计数。

感想

1. 注意long long!!!

2. BIT的上限要>=n!

3. 注意统计影响完成后需要清空树状数组(区间大小已经减少了所以可以浪费地使用),此时不能直接用memset清空整个数组,时间会成为O(n2),超时。

代码

时间: 0.250s

时间复杂度O(cnlogn)

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <queue>
#include <tuple>
#include <cassert> using namespace std; const int MAXN = int(4e5 + ); #define LEFT_CHILD(x) ((x) << 1)
#define RIGHT_CHILD(x) (((x) << 1) + 1)
#define FATHER(x) ((x) >> 1)
#define IS_LEFT_CHILD(x) (((x) & 1) == 0)
#define IS_RIGHT_CHILD(x) (((x) & 1) == 1)
#define BROTHER(x) ((x) ^ 1)
#define LOWBIT(x) ((x) & (-x)) #define LOCAL_DEBUG struct Node{
int value, pos, time;
}nodes[MAXN], tmpNodes[MAXN]; int timeCnt[MAXN * ];
long long revNum[MAXN];
int clearStack[MAXN];
int clearLen;
int n, m;
int bitLimit; int getHigherBit(int n) {
int x = ;
while (x < n) { x <<= ; }
return x;
} void update(int id) {
while (id <= bitLimit) {
if (timeCnt[id] == ) {
clearStack[clearLen++] = id;
}
timeCnt[id]++;
id += LOWBIT(id);
}
} void clearCnt() {
while (clearLen > ) {
timeCnt[clearStack[--clearLen]] = ;
}
} int countTimesSmaller(int id) {
if (id < )return ;
int sum = ;
int tmp = ;
while (id > ) {
sum += timeCnt[id];
id -= LOWBIT(id);
}
return sum;
} void merge_by_pos(int root_ind, int internal_l, int internal_r) {
int internal_mid = (internal_l + internal_r) >> ;
for (int i = internal_l; i <= internal_r; i++) {
tmpNodes[i] = nodes[i];
}
for (int i = internal_l, j = internal_mid + , ind = internal_l; ind <= internal_r; ) {
if (i > internal_mid) {
nodes[ind++] = tmpNodes[j++];
}
else if (j > internal_r) {
nodes[ind++] = tmpNodes[i++];
}
else if (tmpNodes[i].pos < tmpNodes[j].pos) {
nodes[ind++] = tmpNodes[i++];
}
else {
nodes[ind++] = tmpNodes[j++];
}
}
}
void cal(int root_ind, int internal_l, int internal_r) {
if (internal_l == internal_r)return;
int internal_mid = (internal_l + internal_r) >> ;
if(internal_l != internal_mid)cal(LEFT_CHILD(root_ind), internal_l, internal_mid);
if (internal_mid + != internal_r)cal(RIGHT_CHILD(root_ind), internal_mid + , internal_r);
// printf("L Node: %d[%d, %d] LC: %d[%d, %d], RC: %d[%d, %d]\n", root_ind, internal_l, internal_r, LEFT_CHILD(root_ind), internal_l, internal_mid, RIGHT_CHILD(root_ind), internal_mid + 1, internal_r);
for (int i = internal_l, j = internal_mid + ; i <= internal_mid; i++) {
while (j <= internal_r && nodes[i].pos > nodes[j].pos) {
update(nodes[j].time);
j++;
}
revNum[nodes[i].time] += countTimesSmaller(nodes[i].time);
// printf("L (%d, %d, %d): +%d\n", nodes[i].value, nodes[i].pos, nodes[i].time, countTimesSmaller(nodes[i].time));
}
clearCnt(); for (int i = internal_mid, j = internal_r; j > internal_mid; j--) {
while (i >= internal_l && nodes[i].pos > nodes[j].pos) {
update(nodes[i].time);
i--;
}
revNum[nodes[j].time] += countTimesSmaller(nodes[j].time - );
// printf("R (%d, %d, %d): +%d\n", nodes[j].value, nodes[j].pos, nodes[j].time, countTimesSmaller(nodes[j].time - 1));
}
clearCnt();
merge_by_pos(root_ind, internal_l, internal_r); } int main() {
#ifdef LOCAL_DEBUG
freopen("input.txt", "r", stdin);
freopen("output2.txt", "w", stdout);
#endif // LOCAL_DEBUG
for (int ti = ; scanf("%d%d", &n, &m) == ; ti++) {
bitLimit = getHigherBit(n);
for (int i = ; i <= n; i++) {
int tmp;
scanf("%d", &tmp);
nodes[tmp].value = tmp;
nodes[tmp].pos = i;
nodes[tmp].time = ;
}
for (int i = ; i <= m + ; i++) { revNum[i] = ; }
for (int i = ; i < m; i++) {
int tmp;
scanf("%d", &tmp);
nodes[tmp].time = m - i + ;
}
cal(, , n);
long long ans = ;
for (int i = ; i <= m + ; i++) { ans += revNum[i]; }
for (int i = ; i < m; i++) {
printf("%lld\n", ans);
ans -= revNum[m - i + ];
}
}
return ;
}