题目链接
题解
我们把每次复制出来的树看做一个点,那么大树实际上也就是一棵\(O(M)\)个点的树
所以我们只需求两遍树上距离:
先在大树上(以块为点)求距离,跳到同一个块之后再在模板树上求一次距离
需要分情况讨论:两点在同一块内(直接在模板树上求距离);一块是另一块的祖先;两块互不为祖先(先跳到两块的 LCA 下方,再在模板树上求距离)
另外,要确定大树上某个编号对应模板树中的哪个点,需要在模板树的子树中求编号第\(k\)小的点,这要用主席树(可持久化线段树,按 DFS 序建版本)
没了
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<map>
// Iterate over the adjacency list of template-tree node u. `to` is declared
// here for the loop body to fill with ed[k].to.
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
// 1-based inclusive loop: i = 1 .. n.
#define REP(i,n) for (int i = 1; i <= (n); i++)
// Build a pair<int,int>. Spelling it make_pair<int,int>(a,b) turns the
// parameters into int&& (C++11+), so mp(x,y) with lvalue ints would not
// compile. Constructing the pair directly keeps the same result type (cp).
#define mp(a,b) pair<int,int>((a),(b))
// Zero a fixed-size array (uses sizeof, so only works on true arrays).
#define cls(s) memset(s,0,sizeof(s))
#define cp pair<int,int>
#define LL long long int
using namespace std;
// maxn: capacity of per-node / per-block arrays; maxm: node pool of the
// persistent segment tree (sum/ls/rs); INF is currently unused in this file.
const int maxn = 100005,maxm = 6000005,INF = 1000000000;
// Reads the next integer from stdin. Non-digit characters before it are
// skipped; a '-' seen among them makes the result negative.
inline LL read(){
    int ch = getchar();
    LL sign = 1;
    for (; ch < '0' || ch > '9'; ch = getchar())
        if (ch == '-') sign = -1;
    LL value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        value = value * 10 + (ch - '0');
    return value * sign;
}
// Big-tree labels S[i-1]+1 .. S[i] belong to block i (S[0] == 0); S[N] is the
// total number of big-tree nodes.
LL S[maxn];
// R[i]: template node whose subtree was copied to create block i (R[1] == 1).
LL R[maxn];
// Template-tree size, number of copy operations, number of queries.
int n,m,Q;
// bin[i] == 2^i, filled by main for i = 0..25.
int bin[50];
// Template-tree adjacency lists: h[u] is u's first edge and ed[k].nxt the next
// one. Edge indices start at 2 because ne starts at 1.
int h[maxn],ne = 1;
struct EDGE{
    int to;
    int nxt;
};
EDGE ed[maxn << 1];
// Adds the undirected template-tree edge (u, v) as two directed arcs, each
// prepended to its endpoint's adjacency list. Uses plain member assignment
// instead of the non-standard (GNU) compound literal (EDGE){...}.
inline void build(int u,int v){
    ++ne;
    ed[ne].to = v;
    ed[ne].nxt = h[u];
    h[u] = ne;
    ++ne;
    ed[ne].to = u;
    ed[ne].nxt = h[v];
    h[v] = ne;
}
// Persistent segment tree over template labels [1, n]. Node 0 is the shared
// empty tree; sum[u] counts the labels stored under node u, ls/rs are its
// children, tot is the last allocated node. rt[i] is filled by dfs as the
// version holding the labels of the first i nodes in DFS order.
int sum[maxm],ls[maxm],rs[maxm],rt[maxn],tot;
// Stores in `u` a new version equal to version `pre` with label `pos` added
// once. Only the nodes on the path to the leaf `pos` are copied.
void modify(int& u,int pre,int l,int r,int pos){
    int* slot = &u;
    for (;;){
        int fresh = ++tot;
        *slot = fresh;
        sum[fresh] = sum[pre] + 1;
        ls[fresh] = ls[pre];
        rs[fresh] = rs[pre];
        if (l == r) return;
        int mid = (l + r) >> 1;
        if (pos <= mid){
            slot = &ls[fresh];
            pre = ls[pre];
            r = mid;
        }
        else {
            slot = &rs[fresh];
            pre = rs[pre];
            l = mid + 1;
        }
    }
}
// Returns the k-th smallest label that is in version u but not in version v,
// searching the value range [l, r].
// NOTE(review): assumes 1 <= k <= sum[u] - sum[v]; callers appear to
// guarantee this — confirm.
int query(int u,int v,int l,int r,LL k){
    while (l != r){
        int mid = (l + r) >> 1;
        int leftCount = sum[ls[u]] - sum[ls[v]];
        if (k <= leftCount){
            u = ls[u];
            v = ls[v];
            r = mid;
        }
        else {
            k -= leftCount;
            u = rs[u];
            v = rs[v];
            l = mid + 1;
        }
    }
    return l;
}
int dfn[maxn],siz[maxn],dep[maxn],Fa[maxn][18],cnt;
void dfs(int u){
dfn[u] = ++cnt; siz[u] = 1;
modify(rt[cnt],rt[cnt - 1],1,n,u);
REP(i,17) Fa[u][i] = Fa[Fa[u][i - 1]][i - 1];
Redge(u) if ((to = ed[k].to) != Fa[u][0]){
Fa[to][0] = u; dep[to] = dep[u] + 1;
dfs(to);
siz[u] += siz[to];
}
}
// Block tree. Blocks are numbered 1..N in creation order (block 1 is the
// original template tree). Child lists: hh[b] is b's first entry, Nxt links
// the entries, To[e] is the child block.
int hh[maxn],nne = 1;
int Nxt[maxn << 1],To[maxn << 1];
// fa[i][j]: 2^j-th ancestor block of i (0 above block 1); N: number of blocks;
// lk[i]: template node, inside the parent block, that block i hangs from;
// Dep[i]: depth of block i in the block tree (block 1 has depth 0).
int fa[maxn][18],N,lk[maxn],Dep[maxn];
// d[i][j]: big-tree distance from the root of block i to the root of fa[i][j].
LL d[maxn][18];
// Recursively sets Dep for every block below u using the child lists.
// Dep[u] itself must already be set.
void DFS(int u){
    int e = hh[u];
    while (e != 0){
        const int child = To[e];
        Dep[child] = Dep[u] + 1;
        DFS(child);
        e = Nxt[e];
    }
}
void Build(){
LL u,b,x,r,v;
S[1] = n; R[1] = 1; N = 1;
for (int i = 1; i <= m; i++){
u = read(); v = read();
b = lower_bound(S + 1,S + 1 + N,v) - S; r = R[b];
x = query(rt[dfn[r] + siz[r] - 1],rt[dfn[r] - 1],1,n,v - S[b - 1]);
N++;
fa[N][0] = b; d[N][0] = 1 + dep[x] - dep[r];
S[N] = S[N - 1] + siz[u]; R[N] = u; lk[N] = x;
nne++;
Nxt[nne] = hh[b]; To[nne] = N; hh[b] = nne;
}
DFS(1);
//REP(i,N) printf("block %d rt = %lld lk = %d d = %lld Dep = %d total = %lld\n",i,R[i],lk[i],d[i][0],Dep[i],S[i]);
REP(j,17) REP(i,N){
fa[i][j] = fa[fa[i][j - 1]][j - 1];
d[i][j] = d[i][j - 1] + d[fa[i][j - 1]][j - 1];
}
}
// Number of edges between template nodes u and v, using binary lifting on Fa.
// First lift the deeper node to the other's depth, then lift both to just
// below their lowest common ancestor.
LL dis(int u,int v){
    if (dep[u] < dep[v]) swap(u,v);
    LL steps = 0;
    int gap = dep[u] - dep[v];
    for (int i = 0; bin[i] <= gap; i++){
        if (!(gap & bin[i])) continue;
        steps += bin[i];
        u = Fa[u][i];
    }
    if (u == v) return steps;
    for (int i = 17; i >= 0; i--){
        if (Fa[u][i] == Fa[v][i]) continue;
        u = Fa[u][i];
        v = Fa[v][i];
        steps += bin[i + 1];   // one jump of 2^i on each side
    }
    // the final edge from each side into the LCA
    return steps + 2;
}
LL Dis(LL a,LL b,LL x,LL y){
if (Dep[a] < Dep[b]){
swap(a,b);
swap(x,y);
}
LL re = dep[x] - dep[R[a]] + dep[y] - dep[R[b]];
int D = Dep[a] - Dep[b],u = a,v = b;
for (int i = 0; bin[i] <= D; i++)
if (D & bin[i]){
re += d[u][i];
u = fa[u][i];
}
if (u == v){
D = Dep[a] - Dep[b] - 1; u = a;
re = dep[x] - dep[R[a]];
for (int i = 0; bin[i] <= D; i++)
if (D & bin[i]){
re += d[u][i];
u = fa[u][i];
}
re += 1;
x = lk[u];
re += dis(x,y);
}
else {
for (int i = 17; ~i; i--)
if (fa[u][i] != fa[v][i]){
re += d[u][i] + d[v][i];
u = fa[u][i];
v = fa[v][i];
}
re += 2;
x = lk[u]; y = lk[v];
re += dis(x,y);
}
return re;
}
void solve(){
LL u,v,a,b,x,y;
while (Q--){
u = read(); v = read();
a = lower_bound(S + 1,S + 1 + N,u) - S;
b = lower_bound(S + 1,S + 1 + N,v) - S;
x = query(rt[dfn[R[a]] + siz[R[a]] - 1],rt[dfn[R[a]] - 1],1,n,u - S[a - 1]);
y = query(rt[dfn[R[b]] + siz[R[b]] - 1],rt[dfn[R[b]] - 1],1,n,v - S[b - 1]);
//printf("(%lld,%lld) (%lld,%lld)\n",a,x,b,y);
if (a == b) printf("%lld\n",dis(x,y));
else printf("%lld\n",Dis(a,b,x,y));
}
}
// Entry point: reads the template tree, preprocesses it, builds the block tree
// from the copy operations, then answers the queries.
int main(){
    // powers of two for the binary-lifting loops
    for (int i = 0; i <= 25; i++) bin[i] = 1 << i;
    n = read(); m = read(); Q = read();
    for (int i = 1; i < n; i++){
        int a = read();
        int b = read();
        build(a,b);
    }
    dfs(1);    // template tree: dfn/siz/dep/Fa and persistent versions rt
    Build();   // block tree from the m copy operations
    solve();   // answer the Q queries
    return 0;
}