BZOJ 2588: Spoj 10628. Count on a tree [树上主席树]

时间:2022-11-29 22:33:03

2588: Spoj 10628. Count on a tree

Time Limit: 12 Sec  Memory Limit: 128 MB
Submit: 5217  Solved: 1233
[Submit][Status][Discuss]

Description

给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。

Input

第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。

Output

M行,表示每个询问的答案。

Sample Input

8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2

Sample Output

2
8
9
105
7

HINT

HINT:
N,M<=100000

每个节点建一棵主席树,维护树上前缀和
每次u和v之间的就是ls(u)+ls(v)-ls(lca(u,v))-ls(fa[lca(u,v)])
很多人用了dfs序的编号,没有必要
 
注意:
1.数组空间!!!!
2.lca用树剖的话,建树时dfs的顺序也要先重链!!!!!
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=1e5+,INF=1e9+;
int read(){
char c=getchar();int x=,f=;
while(c<''||c>''){if(c=='-')f=-; c=getchar();}
while(c>=''&&c<=''){x=x*+c-''; c=getchar();}
return x*f;
}
int n,Q,u,v,k,last,mp[N],m,w[N];
struct ques{
int u,v,k;
}q[N];
int Bin(int v){
int l=,r=m;
while(l<=r){
int mid=(l+r)>>;
if(mp[mid]==v) return mid;
else if(v<mp[mid]) r=mid-;
else l=mid+;
}
return -;
}
struct edge{
int v,ne;
}e[N<<];
int h[N],cnt;
inline void ins(int u,int v){
cnt++;
e[cnt].v=v;e[cnt].ne=h[u];h[u]=cnt;
cnt++;
e[cnt].v=u;e[cnt].ne=h[v];h[v]=cnt;
}
int size[N],fa[N],deep[N],mx[N],top[N],tid[N],tot;
void dfs(int u){
size[u]=;
for(int i=h[u];i;i=e[i].ne){
int v=e[i].v;
if(v==fa[u]) continue;
fa[v]=u;deep[v]=deep[u]+;
dfs(v);
size[u]+=size[v];
if(size[mx[u]]<size[v]) mx[u]=v;
}
}
void dfs(int u,int anc){
if(!u) return;
tid[u]=++tot;
top[u]=anc;
dfs(mx[u],anc);
for(int i=h[u];i;i=e[i].ne){
int v=e[i].v;
if(v!=fa[u]&&v!=mx[u]) dfs(v,v);
}
}
int lca(int x,int y){
while(top[x]!=top[y]){
if(deep[top[x]]<deep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(tid[x]>tid[y]) swap(x,y);
return x;
} struct node{
int l,r,size;
}t[N*];
int sz,root[N];
void insert(int &x,int l,int r,int num){
t[++sz]=t[x];x=sz;
t[x].size++;
if(l==r) return;
int mid=(l+r)>>;
if(num<=mid) insert(t[x].l,l,mid,num);
else insert(t[x].r,mid+,r,num);
}
void build(int u){
root[u]=root[fa[u]];
insert(root[u],,m,Bin(w[u]));
if(mx[u]) build(mx[u]);
for(int i=h[u];i;i=e[i].ne)
if(e[i].v!=fa[u]&&e[i].v!=mx[u]) build(e[i].v);
}
inline int ls(int x){return t[t[x].l].size;}
int query(int u,int v,int k){
int p=lca(u,v),q=fa[p];
u=root[u];v=root[v];p=root[p];q=root[q];
int l=,r=m;
while(l!=r){
int mid=(l+r)>>,_=ls(u)+ls(v)-ls(p)-ls(q);
if(k<=_) r=mid,u=t[u].l,v=t[v].l,p=t[p].l,q=t[q].l;
else l=mid+,u=t[u].r,v=t[v].r,p=t[p].r,q=t[q].r,k-=_;
}
return l;
}
int main(){
n=read();Q=read();
for(int i=;i<=n;i++) mp[i]=w[i]=read();
for(int i=;i<=n-;i++) u=read(),v=read(),ins(u,v);
dfs();dfs(,);
for(int i=;i<=Q;i++) q[i].u=read(),q[i].v=read(),q[i].k=read();
sort(mp+,mp++n);
m=;
for(int i=;i<=n;i++) if(mp[i]!=mp[i-]) mp[++m]=mp[i]; build();
for(int i=;i<=Q;i++){
u=last^q[i].u;v=q[i].v;k=q[i].k;
last=mp[query(u,v,k)];
printf("%d",last);
if(i!=Q) putchar('\n');
}
}