bzoj 4539: [Hnoi2016]树

时间:2023-03-09 17:33:25
bzoj 4539: [Hnoi2016]树

Description

  小A想做一棵很大的树,但是他手上的材料有限,只好用点小技巧了。开始,小A只有一棵结点数为N的树,结
点的编号为1,2,…,N,其中结点1为根;我们称这颗树为模板树。小A决定通过这棵模板树来构建一颗大树。构建过
程如下:(1)将模板树复制为初始的大树。(2)以下(2.1)(2.2)(2.3)步循环执行M次(2.1)选择两个数字a,b,
其中1<=a<=N,1<=b<=当前大树的结点数。(2.2)将模板树中以结点a为根的子树复制一遍,挂到大树中结点b的下
方(也就是说,模板树中的结点a为根的子树复制到大树中后,将成为大树中结点b的子树)。(2.3)将新加入大树
的结点按照在模板树中编号的顺序重新编号。例如,假设在进行2.2步之前大树有L个结点,模板树中以a为根的子
树共有C个结点,那么新加入模板树的C个结点在大树中的编号将是L+1,L+2,…,L+C;大树中这C个结点编号的大小
顺序和模板树中对应的C个结点的大小顺序是一致的。下面给出一个实例。假设模板树如下图:

bzoj 4539: [Hnoi2016]树
根据第(1)步,初始的大树与模板树是相同的。在(2.1)步,假设选择了a=4,b=3。运行(2.2)和(2.3)后,得到新的
大树如下图所示
bzoj 4539: [Hnoi2016]树
现在他想问你,树中一些结点对的距离是多少。

Input

  第一行三个整数:N,M,Q,以空格隔开,N表示模板树结点数,M表示第(2)中的循环操作的次数,Q 表示询问数
量。接下来N-1行,每行两个整数 fr,to,表示模板树中的一条树边。再接下来M行,每行两个整数x,to,表示将模
板树中 x 为根的子树复制到大树中成为结点to的子树的一次操作。再接下来Q行,每行两个整数fr,to,表示询问
大树中结点 fr和 to之间的距离是多少。N,M,Q<=100000

Output

  输出Q行,每行一个整数,第 i行是第 i个询问的答案。

Sample Input

5 2 3
1 4
1 3
4 2
4 5
4 3
3 2
6 9
1 8
5 3

Sample Output

6
3
3

HINT

经过两次操作后,大树变成了下图所示的形状:

bzoj 4539: [Hnoi2016]树

结点6到9之间经过了6条边,所以距离为6;类似地,结点1到8之间经过了3条边;结点5到3之间也经过了3条边。

Source

树链剖分+倍增+主席树+二分;

真的是一道码得爽彻心扉的大码农题,在考场上写这个题真的就是在赌博...;

首先因为直接复制的话点数是n^2的,不可能存得下,我们很容易想到构一棵超级树,树上的每个节点都代表一个复制上去的子树;

对于超级树上的节点存储两个信息:子树的根,节点的序号范围;

我们考虑如何查询题目中所说的大树的点在超级树中的编号,这个通过二分可以找到在超级树的哪个节点上;

然后我们想知道他对应原树的哪个节点,因为子树内的编号顺序与原树相同,那么我们相当于查询该超级点所代表的子树的第k小编号即可,那么我们对原树按dfs序建主席树就能查询了;

然后我们考虑如何处理询问:
首先我们需要找到这两个点对应在超级树中的位置,如果在同一个超级点中,那么直接查询两点子在原树上的距离;

如果不是,每个点先跳到该超级点所代表子树的根,然后往上跳到超级树上lca的儿子(用倍增来搞),然后进入lca所代表的子树中,然后查询两点间的距离;(如果lca为其中某一个节点则特殊处理)

在询问过程中我们需要知道一些值:

1.在一棵子树内部移动的距离,这个对原树进行树链剖分就可以很好地得到;

2.超级树上的边权,设新的点x,连在超级树的点y所代表的子树的z节点的下面,那么x->y的边权应该代表从x的子树的根到y的子树的根的距离,所以为dis(z,y.rt)+1;

然后我们用距离的前缀和数组相减即得到了在超级树上移动的距离;

这样就可以嘴巴AC了,但细节是真的多,我不把原树当超级点搞就错了;

//MADE BY QT666
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
const int N=800050;
int n,m,Q,tot;
ll sum;
int head[N],to[N],nxt[N],cnt,Father[N];
ll deep[N],d[N],dep[N];
void lnk(int x,int y){
to[++cnt]=y,nxt[cnt]=head[x],head[x]=cnt;
to[++cnt]=x,nxt[cnt]=head[y],head[y]=cnt;
}
struct data{
int dfn[N],ed[N],id[N],tmp,son[N],top[N],fa[N],sz[N];
int rt[N],ls[N*4],rs[N*4],size[N*4],tt;
void dfs1(int x,int f){
sz[x]=1;d[x]=deep[x];Father[x]=fa[x];
for(int i=head[x];i;i=nxt[i]){
int y=to[i];if(y==f) continue;
fa[y]=x;deep[y]=deep[x]+1;dfs1(y,x);
sz[x]+=sz[y];
if(sz[y]>sz[son[x]]) son[x]=y;
}
}
void dfs2(int x,int f){
top[x]=f;dfn[x]=++tmp;id[tmp]=x;
if(son[x]) dfs2(son[x],f);
for(int i=head[x];i;i=nxt[i]){
int y=to[i];if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
ed[x]=tmp;
}
int Lca(int x,int y){
while(top[x]!=top[y]){
if(deep[top[x]]<deep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
return y;
}
int dis(int x,int y){
int lca=Lca(x,y);
return deep[x]+deep[y]-2*deep[lca];
}
void insert(int &y,int x,int l,int r,int v){
y=++tt;ls[y]=ls[x],rs[y]=rs[x],size[y]=size[x]+1;
if(l==r) return;
int mid=(l+r)>>1;
if(v<=mid) insert(ls[y],ls[x],l,mid,v);
else insert(rs[y],rs[x],mid+1,r,v);
}
int query(int y,int x,int l,int r,int k){
if(l==r) return l;
int mid=(l+r)>>1;
if(size[ls[y]]-size[ls[x]]>=k) return query(ls[y],ls[x],l,mid,k);
else return query(rs[y],rs[x],mid+1,r,k-(size[ls[y]]-size[ls[x]]));
}
void build(){
fa[1]=1;dfs1(1,0);dfs2(1,1);
for(int i=1;i<=n;i++) insert(rt[i],rt[i-1],1,n,id[i]);
}
}tree;
struct date{
struct node{
ll l,r;int rt;
}xh[N];
struct pir{
int x,rt,pos;
};
int fa[N][20];
pir getpos(ll x){
int l=1,r=tot,ans;
while(l<=r){
int mid=(l+r)>>1;
if(x<=xh[mid].r) r=mid-1,ans=mid;
else l=mid+1;
}
int id=xh[ans].rt;
return (pir){tree.query(tree.rt[tree.ed[id]],tree.rt[tree.dfn[id]-1],1,n,x-xh[ans].l+1),xh[ans].rt,ans};
}
void add(ll x,ll y){
tot++;pir v=getpos(y);
xh[tot].l=sum+1,xh[tot].r=sum+tree.sz[x],xh[tot].rt=x;
sum+=tree.sz[x];
d[tot]=d[v.pos]+tree.dis(v.x,v.rt)+1;
dep[tot]=dep[v.pos]+1;
fa[tot][0]=v.pos;Father[tot]=v.x;
for(int j=1;j<=16;j++) fa[tot][j]=fa[fa[tot][j-1]][j-1];
}
int Lca(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
for(int i=18;i>=0;i--)
if(dep[fa[u][i]]>=dep[v]) u=fa[u][i];
if(u==v) return u;
for(int i=18;i>=0;i--)
if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
ll ask(ll x,ll y){
pir u=getpos(x),v=getpos(y);
if(dep[u.pos]<=dep[v.pos]) swap(u,v);
if(u.pos==v.pos) return tree.dis(u.x,v.x);
else{
int lca=Lca(u.pos,v.pos);
if(lca==v.pos){
int g=u.pos;
for(int i=18;i>=0;i--){
if(fa[g][i]&&dep[fa[g][i]]>dep[lca]) g=fa[g][i];
}
ll ret=d[u.pos]-d[g]+1+tree.dis(u.x,u.rt);
ret+=tree.dis(Father[g],v.x);
return ret;
}
else{
int p=u.pos,q=v.pos;
for(int i=18;i>=0;i--){
if(fa[p][i]&&dep[fa[p][i]]>dep[lca]) p=fa[p][i];
}
for(int i=18;i>=0;i--){
if(fa[q][i]&&dep[fa[q][i]]>dep[lca]) q=fa[q][i];
}
ll ret=d[u.pos]-d[p]+1+tree.dis(u.x,u.rt)+d[v.pos]-d[q]+1+tree.dis(v.x,v.rt);
ret+=tree.dis(Father[p],Father[q]);
return ret;
}
}
}
}super;
int main(){
scanf("%d%d%d",&n,&m,&Q);
for(int i=1;i<n;i++){
int x,y;scanf("%d%d",&x,&y);
lnk(x,y);
}
tree.build();tot=1;sum=n;dep[1]=1;super.fa[1][0]=1;
super.xh[1].l=1,super.xh[1].r=n,super.xh[1].rt=1;
for(int i=1;i<=m;i++){
ll x,y;scanf("%lld%lld",&x,&y);
super.add(x,y);
}
for(int i=1;i<=Q;i++){
ll x,y;scanf("%lld%lld",&x,&y);
printf("%lld\n",super.ask(x,y));
}
return 0;
}