题目描述
对于序列A,它的逆序对数定义为满足i<j,且Ai>Aj的数对(i,j)的个数。给1到n的一个排列,按照某种顺序依次删除m个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。
输入
输入第一行包含两个整数n和m,即初始元素的个数和删除的元素个数。以下n行每行包含一个1到n之间的正整数,即初始排列。以下m行每行一个正整数,依次为每次删除的元素。
输出
输出包含m行,依次为删除每个元素之前,逆序对的个数。
样例输入
5 4
1
5
3
4
2
5
1
4
2
样例输出
5
2
2
1
题解
个人不喜欢CDQ分治,所以写了个线段树套SBT
想法很自然,删除某个数,减少的贡献为它左边比它大的数的个数+它右边比它小的数的个数。外层维护区间线段树,内层维护平衡树(不用权值线段树因为卡空间),查找时找到对应区间在平衡树中查询;删除时把外层从根到对应叶子的每个节点在平衡树中删除掉。
然而写到一半CQzhangyu告诉我本题卡树套树,看了下Discuss发现还真是 = =。
于是赶紧把Treap换成SBT,然而还是TLE。
没办法,再把数组版改成结构体版,最终AC。
然而跑得还是比CDQ分治慢了5倍左右= =
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#define N 100010
#define lson l , mid , x << 1
#define rson mid + 1 , r , x << 1 | 1
using namespace std;
struct data
{
int l , r , w , si;
}a[N << 5];
int pos[N] , v[N] , root[N << 2] , tot;
inline int read()
{
int ret = 0; char ch = getchar();
while(ch < '0' || ch > '9') ch = getchar();
while(ch >= '0' && ch <= '9') ret = (ret << 3) + (ret << 1) + ch - '0' , ch = getchar();
return ret;
}
void zig(int &k)
{
int t = a[k].l;
a[k].l = a[t].r , a[t].r = k , a[t].si = a[k].si , a[k].si = a[a[k].l].si + a[a[k].r].si + 1;
k = t;
}
void zag(int &k)
{
int t = a[k].r;
a[k].r = a[t].l , a[t].l = k , a[t].si = a[k].si , a[k].si = a[a[k].l].si + a[a[k].r].si + 1;
k = t;
}
void maintain(int &k , bool flag)
{
if(!flag)
{
if(a[a[a[k].l].l].si > a[a[k].r].si) zig(k);
else if(a[a[a[k].l].r].si > a[a[k].r].si) zag(a[k].l) , zig(k);
else return;
}
else
{
if(a[a[a[k].r].r].si > a[a[k].l].si) zag(k);
else if(a[a[a[k].r].l].si > a[a[k].l].si) zig(a[k].r) , zag(k);
else return;
}
maintain(a[k].l , false) , maintain(a[k].r , true);
maintain(k , false) , maintain(k , true);
}
void add(int &k , int x)
{
if(!k) k = ++tot , a[k].w = x , a[k].si = 1;
else
{
a[k].si ++ ;
if(x < a[k].w) add(a[k].l , x);
else add(a[k].r , x);
maintain(k , x >= a[k].w);
}
}
void del(int &k , int x)
{
a[k].si -- ;
if(x < a[k].w) del(a[k].l , x);
else if(x > a[k].w) del(a[k].r , x);
else
{
if(!a[k].l || !a[k].r) k = a[k].l + a[k].r;
else
{
int t = a[k].r , last = k;
while(a[t].l) a[t].si -- , last = t , t = a[t].l;
if(t == a[last].l) a[last].l = a[t].r;
else a[last].r = a[t].r;
a[t].l = a[k].l , a[t].r = a[k].r , a[t].si = a[k].si , k = t;
}
}
}
int findl(int k , int x)
{
if(!k) return 0;
else if(x <= a[k].w) return findl(a[k].l , x);
else return findl(a[k].r , x) + a[a[k].l].si + 1;
}
int findr(int k , int x)
{
if(!k) return 0;
else if(x >= a[k].w) return findr(a[k].r , x);
else return findr(a[k].l , x) + a[a[k].r].si + 1;
}
void insert(int p , int a , int l , int r , int x)
{
add(root[x] , a);
if(l == r) return;
int mid = (l + r) >> 1;
if(p <= mid) insert(p , a , lson);
else insert(p , a , rson);
}
void erase(int p , int a , int l , int r , int x)
{
del(root[x] , a);
if(l == r) return;
int mid = (l + r) >> 1;
if(p <= mid) erase(p , a , lson);
else erase(p , a , rson);
}
int queryl(int b , int e , int a , int l , int r , int x)
{
if(b <= l && r <= e) return findl(root[x] , a);
int mid = (l + r) >> 1 , ans = 0;
if(b <= mid) ans += queryl(b , e , a , lson);
if(e > mid) ans += queryl(b , e , a , rson);
return ans;
}
int queryr(int b , int e , int a , int l , int r , int x)
{
if(b <= l && r <= e) return findr(root[x] , a);
int mid = (l + r) >> 1 , ans = 0;
if(b <= mid) ans += queryr(b , e , a , lson);
if(e > mid) ans += queryr(b , e , a , rson);
return ans;
}
int main()
{
int n , m , i , x;
long long ans = 0;
n = read() , m = read();
for(i = 1 ; i <= n ; i ++ )
v[i] = read() , insert(i , v[i] , 1 , n , 1) , ans += queryr(1 , i , v[i] , 1 , n , 1) , pos[v[i]] = i;
while(m -- )
{
x = read() , printf("%lld\n" , ans);
ans -= queryr(1 , pos[x] , x , 1 , n , 1) + queryl(pos[x] , n , x , 1 , n , 1);
erase(pos[x] , x , 1 , n , 1);
}
return 0;
}