[BZOJ 3196] 213平衡树 【线段树套set + 树状数组套线段树】

时间:2023-03-08 16:24:33
[BZOJ 3196] 213平衡树 【线段树套set + 树状数组套线段树】

题目链接:BZOJ - 3196

题目分析

区间第 k 小（Kth）和区间排名（Rank）用树状数组套（动态开点）线段树实现；区间前驱、后继用线段树套 set（multiset）实现。

为了节省空间，需要离线处理：先读入全部操作，把初始数组和所有操作中出现的数值一起离散化。这样值域变小，线段树的节点池等数组也可以开得小一些，勉强能卡过 128MB 的内存限制 = =

嗯就是这样,代码长度= =我写了260行......Debug了n小时= =

代码

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <set>
#include <map>
using namespace std;

// Capacity limits. MaxN bounds the array length and MaxM the number of
// operations; MN is the upper end of the value range [0, MN] covered by each
// value segment tree; MaxNode is the size of the shared node pool.
const int MaxN = 50000 + 5;
const int MaxM = 50000 + 5;
const int MN = 100000 + 15;
const int INF = 999999999;
const int MaxNode = 8000000 + 15;

int n, m;        // array length and number of operations
int Index;       // id of the last node allocated from the T / Son pool
int Used_Index;  // current step stamp, compared against C[] when moving cursors
int Top;         // number of raw values collected in Que
int Hash_Index;  // number of distinct values after discretization

int A[MaxN];          // array values: raw after input, discretized ids afterwards
int Root[MaxN];       // Root[i]: root node of Fenwick node i's value tree (0 = none)
int T[MaxNode];       // T[x]: number of values stored under node x
int Son[MaxNode][2];  // Son[x][0 / 1]: left / right child of node x (0 = none)
int U[MaxN];          // U[i]: cursor into Fenwick node i's value tree during a descent
int C[MaxN];          // C[i]: stamp of the step that last moved U[i]
int Que[MaxN + MaxM]; // every value seen in the input, sorted for discretization
int TR[MaxN + MaxM];  // TR[id]: original value of discretized id

// One offline operation: f is the opcode, L / R the range, k the rank asked
// for by op 2, Num the value operand and Pos the position updated by op 3.
struct Query
{
    int f, L, R, k, Num, Pos;
} Q[MaxM];

map<int, int> M;            // original value -> discretized id (1-based)
multiset<int> S[MaxN * 4];  // heap-indexed position segment tree; S[x] holds the values in node x's range
// Scratch iterator shared by the predecessor / successor queries.
std::multiset<int>::iterator It;

/// Returns the smaller of two ints.
inline int gmin(int a, int b)
{
    if (a < b) return a;
    return b;
}
inline int gmax(int a, int b) {return a > b ? a : b;} void Add(int &x, int s, int t, int Pos, int Num)
{
if (x == 0) x = ++Index;
T[x] += Num;
if (s == t) return;
int m = (s + t) >> 1;
if (Pos <= m) Add(Son[x][0], s, m, Pos, Num);
else Add(Son[x][1], m + 1, t, Pos, Num);
} void Change(int x, int Pos, int Num)
{
for (int i = x; i <= n; i += i & -i)
Add(Root[i], 0, MN, Pos, Num);
} void Add_S(int x, int s, int t, int Pos, int Num)
{
S[x].insert(Num);
if (s == t) return;
int m = (s + t) >> 1;
if (Pos <= m) Add_S(x << 1, s, m, Pos, Num);
else Add_S(x << 1 | 1, m + 1, t, Pos, Num);
} void Del_S(int x, int s, int t, int Pos, int Num)
{
S[x].erase(S[x].find(Num));
if (s == t) return;
int m = (s + t) >> 1;
if (Pos <= m) Del_S(x << 1, s, m, Pos, Num);
else Del_S(x << 1 | 1, m + 1, t, Pos, Num);
} void Init_U(int x)
{
for (int i = x; i; i -= i & -i)
U[i] = Root[i];
} void Turn(int x, int f)
{
for (int i = x; i; i -= i & -i)
{
if (C[i] == Used_Index) break;
C[i] = Used_Index;
U[i] = Son[U[i]][f];
}
} int Get_LSum(int x)
{
int ret = 0;
for (int i = x; i; i -= i & -i)
ret += T[Son[U[i]][0]];
return ret;
} int Before(int x, int s, int t, int l, int r, int Num)
{
int ret;
if (l <= s && r >= t)
{
It = S[x].end();
It--;
if (*It < Num) return *It;
It = S[x].begin();
if (*It >= Num) return -INF;
It = S[x].lower_bound(Num);
It--;
return *It;
}
int m = (s + t) >> 1;
ret = -INF;
if (l <= m) ret = gmax(ret, Before(x << 1, s, m, l, r, Num));
if (r >= m + 1) ret = gmax(ret, Before(x << 1 | 1, m + 1, t, l, r, Num));
return ret;
} int After(int x, int s, int t, int l, int r, int Num)
{
int ret;
if (l <= s && r >= t)
{
It = S[x].upper_bound(Num);
if (It == S[x].end()) return INF;
else return *It;
}
int m = (s + t) >> 1;
ret = INF;
if (l <= m) ret = gmin(ret, After(x << 1, s, m, l, r, Num));
if (r >= m + 1) ret = gmin(ret, After(x << 1 | 1, m + 1, t, l, r, Num));
return ret;
} int main()
{
scanf("%d%d", &n, &m);
Top = 0; Index = 0;
for (int i = 1; i <= n; ++i)
{
scanf("%d", &A[i]);
Que[++Top] = A[i];
}
for (int i = 1; i <= m; ++i)
{
scanf("%d", &Q[i].f);
switch (Q[i].f)
{
case 1 :
scanf("%d%d%d", &Q[i].L, &Q[i].R, &Q[i].Num);
break;
case 2 :
scanf("%d%d%d", &Q[i].L, &Q[i].R, &Q[i].k);
break;
case 3 :
scanf("%d%d", &Q[i].Pos, &Q[i].Num);
break;
case 4 :
scanf("%d%d%d", &Q[i].L, &Q[i].R, &Q[i].Num);
break;
case 5 :
scanf("%d%d%d", &Q[i].L, &Q[i].R, &Q[i].Num);
break;
}
if (Q[i].f != 2) Que[++Top] = Q[i].Num;
}
sort(Que + 1, Que + Top + 1);
Hash_Index = 0;
for (int i = 1; i <= Top; ++i)
{
if (i > 1 && Que[i] == Que[i - 1]) continue;
M[Que[i]] = ++Hash_Index;
TR[Hash_Index] = Que[i];
}
for (int i = 1; i <= n; ++i)
{
A[i] = M[A[i]];
Change(i, A[i], 1);
Add_S(1, 1, n, i, A[i]);
}
int L, R, Pos, Num, k, Temp, l, r, mid;
for (int i = 1; i <= m; ++i)
{
if (Q[i].f != 2) Q[i].Num = M[Q[i].Num];
switch (Q[i].f)
{
case 1 :
L = Q[i].L; R = Q[i].R; Num = Q[i].Num;
Used_Index = 0;
Init_U(L - 1);
Init_U(R);
Temp = 0;
l = 0; r = MN;
while (l < r)
{
++Used_Index;
mid = (l + r) >> 1;
if (Num <= mid)
{
r = mid;
Turn(L - 1, 0);
Turn(R, 0);
}
else
{
Temp += Get_LSum(R) - Get_LSum(L - 1);
l = mid + 1;
Turn(L - 1, 1);
Turn(R, 1);
}
}
printf("%d\n", Temp + 1);
break; case 2 :
L = Q[i].L; R = Q[i].R; k = Q[i].k;
Init_U(L - 1);
Init_U(R);
Used_Index = 0;
Temp = 0;
l = 0; r = MN;
while (l < r)
{
++Used_Index;
mid = (l + r) >> 1;
Temp = Get_LSum(R) - Get_LSum(L - 1);
if (Temp >= k)
{
r = mid;
Turn(L - 1, 0);
Turn(R, 0);
}
else
{
l = mid + 1;
Turn(L - 1, 1);
Turn(R, 1);
k -= Temp;
}
}
printf("%d\n", TR[l]);
break; case 3 :
Pos = Q[i].Pos; Num = Q[i].Num;
Change(Pos, A[Pos], -1);
Del_S(1, 1, n, Pos, A[Pos]);
A[Pos] = Num;
Change(Pos, Num, 1);
Add_S(1, 1, n, Pos, Num);
break; case 4 :
L = Q[i].L; R = Q[i].R; Num = Q[i].Num;
printf("%d\n", TR[Before(1, 1, n, L, R, Num)]);
break; case 5 :
L = Q[i].L; R = Q[i].R; Num = Q[i].Num;
printf("%d\n", TR[After(1, 1, n, L, R, Num)]);
break;
}
}
return 0;
}