BZOJ2733 [HNOI2012]永无乡 平衡树启发式合并

时间:2022-12-16 13:31:49

首先因为题目中涉及到查询第K小值,所以用平衡树来维护每个连通分支的信息。

那么加边这个操作怎么实现呢?其实就是将任意的两个平衡树合并。给我们的直观感受是把小的树合并到大的树里比较高效。

事实上,这样做的话,所有合并操作可以在O(nlog^2n)之内解决。

为什么呢?可以这样来分析。每个节点经过一次合并操作以后,它所在的树的大小至少要加倍,那么也就是说至多一个节点被合并操作影响logn次,每次合并后的插入操作要O(logn)时间,共有n个节点,就得到了O(nlog^2n)的时间复杂度。

吐槽一下数据……刚开始我没判断加边操作的两边是否已经在同一个连通分支内,就直接把树复制了一遍……竟然也AC了。下面的代码是改正以后的代码。

//BZOJ2733
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#include<vector>
#include<queue>
#include<ctime>
#include<cstdlib>
using namespace std;
const int MAXN=100010;
struct Treap_Node
{
	int ch[2],key,dat,size,sub;
}Treap[MAXN<<5];
int p[MAXN],sz[MAXN],root[MAXN],ip[MAXN],n,m,in1,in2,q,tot;
char op[10];
int find(int x)
{
	if(p[x]==x) return x;
	p[x]=find(p[x]);
	return p[x];
}
inline void uni(int i,int j)
{
	sz[find(j)]+=sz[find(i)];
	p[find(i)]=find(j);
}
inline int cmp(int x,int tar)
{
	return (Treap[x].dat>tar)?0:1;
}
inline void maintain(int x)
{
	Treap[x].size=Treap[Treap[x].ch[0]].size+1+Treap[Treap[x].ch[1]].size;
}
inline void rotate(int &x,int d)
{
	int p=Treap[x].ch[d^1];
	Treap[x].ch[d^1]=Treap[p].ch[d];
	Treap[p].ch[d]=x;
	maintain(x);
	maintain(p);
	x=p;
}
void ins(int &x,int tar,int s)
{
	if(!x)
	{
		Treap[++tot].dat=tar,Treap[tot].sub=s,Treap[tot].key=rand();
		Treap[tot].size=1,x=tot;
		return;
	}
	int d=cmp(x,tar);
	ins(Treap[x].ch[d],tar,s);
	if(Treap[Treap[x].ch[d]].key>Treap[x].key) rotate(x,d^1);
	maintain(x);
}
int getKth(int x,int k)
{
	if(k<=Treap[Treap[x].ch[0]].size) return getKth(Treap[x].ch[0],k);
	k-=Treap[Treap[x].ch[0]].size+1;
	if(k<=0) return Treap[x].sub;
	else return getKth(Treap[x].ch[1],k);
}
void mergeto(int x,int &y)
{
	ins(y,Treap[x].dat,Treap[x].sub);
	if(Treap[x].ch[0]) mergeto(Treap[x].ch[0],y);
	if(Treap[x].ch[1]) mergeto(Treap[x].ch[1],y);
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++) scanf("%d",&ip[i]);
	for(int i=1;i<=n;i++) p[i]=i,sz[i]=1;
	for(int i=1;i<=m;i++)
	{
		scanf("%d%d",&in1,&in2);
		if(find(in1)!=find(in2)) uni(in1,in2);
	}
	for(int i=1;i<=n;i++) ins(root[find(i)],ip[i],i);
	scanf("%d",&q);
	for(int i=1;i<=q;i++)
	{
		scanf("%s%d%d",op,&in1,&in2);
		if(op[0]=='B'&&find(in1)!=find(in2))
		{
			int s1=sz[find(in1)],s2=sz[find(in2)];
			if(s1>s2)
			{
				mergeto(root[find(in2)],root[find(in1)]);
				uni(in2,in1);
			}
			else
			{
				mergeto(root[find(in1)],root[find(in2)]);
				uni(in1,in2);
			}
		}
		else if(op[0]=='Q')
		{
			if(sz[find(in1)]<in2) puts("-1");
			else printf("%d\n",getKth(root[find(in1)],in2));
		}
	}
	return 0;
}