【loj6145】「2017 山东三轮集训 Day7」Easy 动态点分治+线段树

时间:2022-05-25 19:28:15

题目描述

给你一棵 $n$ 个点的树,边有边权。$m$ 次询问,每次给出 $l$ 、$r$ 、$x$ ,求 $\text{Min}_{i=l}^r\text{dis}(i,x)$ 。

$n,m\le 10^5$ 。


题解

动态点分治+线段树

分块做法太傻逼了我们把它丢到垃圾桶里。树上距离考虑动态点分治。

求出这棵树的点分树,对每一棵点分树子树开一棵动态开点编号线段树,维护编号在某区间内的点到当前点距离的最大值。

对于一次查询,我们在点分树从 $x$ 到根的路径上所有点对应的线段树上查询 $[l,r]$ 的最大值,$dis(i,x)+query(l,r,root_i)$ 的最大值极为答案。

这样做的正确性比较显然:

1. 每个 $[l,r]$ 内的点都属于这些子树的一个部分内,都被正确统计了一次。

2. 多余统计时,距离只会统计大,不会统计小,没有影响。

时间复杂度 $O(n\log^2 n)$

#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 100010
#define inf 1 << 30
using namespace std;
int head[N] , to[N << 1] , len[N << 1] , next[N << 1] , cnt , deep[N] , pos[N] , md[N << 1][20] , log[N << 1] , tot , si[N] , ms[N] , sum , root , vis[N] , fa[N];
int ls[N * 300] , rs[N * 300] , mn[N * 300] , rt[N] , tp;
inline void add(int x , int y , int z)
{
to[++cnt] = y , len[cnt] = z , next[cnt] = head[x] , head[x] = cnt;
}
void dfs(int x , int pre)
{
int i;
md[++tot][0] = deep[x] , pos[x] = tot;
for(i = head[x] ; i ; i = next[i])
if(to[i] != pre)
deep[to[i]] = deep[x] + len[i] , dfs(to[i] , x) , md[++tot][0] = deep[x];
}
inline int dis(int x , int y)
{
int t = deep[x] + deep[y] , k;
x = pos[x] , y = pos[y];
if(x > y) swap(x , y);
k = log[y - x + 1];
return t - 2 * min(md[x][k] , md[y - (1 << k) + 1][k]);
}
void getroot(int x , int pre)
{
int i;
si[x] = 1 , ms[x] = 0;
for(i = head[x] ; i ; i = next[i])
if(!vis[to[i]] && to[i] != pre)
getroot(to[i] , x) , si[x] += si[to[i]] , ms[x] = max(ms[x] , si[to[i]]);
ms[x] = max(ms[x] , sum - si[x]);
if(ms[x] < ms[root]) root = x;
}
void solve(int x)
{
int i;
vis[x] = 1;
for(i = head[x] ; i ; i = next[i])
if(!vis[to[i]])
sum = si[to[i]] , root = 0 , getroot(to[i] , 0) , fa[root] = x , solve(root);
}
void update(int p , int a , int l , int r , int &x)
{
if(!x) x = ++tp , mn[x] = inf;
mn[x] = min(mn[x] , a);
if(l == r) return;
int mid = (l + r) >> 1;
if(p <= mid) update(p , a , l , mid , ls[x]);
else update(p , a , mid + 1 , r , rs[x]);
}
int query(int b , int e , int l , int r , int x)
{
if(!x) return inf;
if(b <= l && r <= e) return mn[x];
int mid = (l + r) >> 1 , ans = inf;
if(b <= mid) ans = min(ans , query(b , e , l , mid , ls[x]));
if(e > mid) ans = min(ans , query(b , e , mid + 1 , r , rs[x]));
return ans;
}
int main()
{
int n , m , i , j , x , y , z , ans;
scanf("%d" , &n);
for(i = 1 ; i < n ; i ++ ) scanf("%d%d%d" , &x , &y , &z) , add(x , y , z) , add(y , x , z);
dfs(1 , 0);
for(i = 2 ; i <= tot ; i ++ ) log[i] = log[i >> 1] + 1;
for(i = 1 ; i <= log[tot] ; i ++ )
for(j = 1 ; j <= tot - (1 << i) + 1 ; j ++ )
md[j][i] = min(md[j][i - 1] , md[j + (1 << (i - 1))][i - 1]);
ms[0] = sum = n , root = 0 , getroot(1 , 0) , solve(root);
for(i = 1 ; i <= n ; i ++ )
for(j = i ; j ; j = fa[j])
update(i , dis(i , j) , 1 , n , rt[j]);
scanf("%d" , &m);
while(m -- )
{
scanf("%d%d%d" , &x , &y , &z) , ans = inf;
for(i = z ; i ; i = fa[i]) ans = min(ans , dis(i , z) + query(x , y , 1 , n , rt[i]));
printf("%d\n" , ans);
}
return 0;
}