【bzoj3626】[LNOI2014]LCA 树链剖分+线段树

时间:2021-05-08 20:34:19

题目描述

给出一个n个节点的有根树(编号为0到n-1,根节点为0)。一个点的深度定义为这个节点到根的距离+1。
设dep[i]表示点i的深度,LCA(i,j)表示i与j的最近公共祖先。
有q次询问,每次询问给出l r z,求sigma_{l<=i<=r}dep[LCA(i,z)]。
(即,求在[l,r]区间内的每个节点i与z的最近公共祖先的深度之和)

输入

第一行2个整数n q。
接下来n-1行,分别表示点1到点n-1的父节点编号。
接下来q行,每行3个整数l r z。

输出

输出q行,每行表示一个询问的答案。每个答案对201314取模输出

样例输入

5 2
0
0
1
1
1 4 3
1 4 2

样例输出

8
5


题解

树链剖分+线段树

考虑两点LCA的深度,可以看作两个点到根节点的路径交的长度(点的个数)。

而路径交的长度,又可以看作把一条路径上的点权值+1,然后查询另一条路径上的点的权值和。

于是本题转化为:把编号在$[l,r]$内的所有点到根路径上的点权值+1,再查询z到根的点权和。

于是我们可以把问题转化为前缀相减的形式,即求编号在$[1,p]$内的所有点到根路径上的点权值+1,查询z到根的点权和。

将拆成前缀相减后的询问离线,按照$p$排序。按照顺序直接处理对应编号,再查询即可。此时需要支持链上修改、链上查询,使用树链剖分+线段树即可。

时间复杂度$O(n\log^2n)$

#include <cstdio>
#include <algorithm>
#define N 50010
#define lson l , mid , x << 1
#define rson mid + 1 , r , x << 1 | 1
using namespace std;
struct data
{
int p , z , v , id;
data() {}
data(int P , int Z , int V , int Id) {p = P , z = Z , v = V , id = Id;}
bool operator<(const data &a)const {return p < a.p;}
}a[N << 1];
int n , head[N] , to[N] , next[N] , cnt , fa[N] , si[N] , bl[N] , pos[N] , tot , sum[N << 2] , tag[N << 2] , ans[N];
inline void add(int x , int y)
{
to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt;
}
void dfs1(int x)
{
int i;
si[x] = 1;
for(i = head[x] ; i ; i = next[i])
dfs1(to[i]) , si[x] += si[to[i]];
}
void dfs2(int x , int c)
{
int i , k = n;
bl[x] = c , pos[x] = ++tot;
for(i = head[x] ; i ; i = next[i])
if(si[to[i]] > si[k])
k = to[i];
if(k != n)
{
dfs2(k , c);
for(i = head[x] ; i ; i = next[i])
if(to[i] != k)
dfs2(to[i] , to[i]);
}
}
inline void pushdown(int l , int r , int x)
{
if(tag[x])
{
int mid = (l + r) >> 1;
sum[x << 1] += tag[x] * (mid - l + 1) , tag[x << 1] += tag[x];
sum[x << 1 | 1] += tag[x] * (r - mid) , tag[x << 1 | 1] += tag[x];
tag[x] = 0;
}
}
void update(int b , int e , int l , int r , int x)
{
if(b <= l && r <= e)
{
sum[x] += r - l + 1 , tag[x] ++ ;
return;
}
pushdown(l , r , x);
int mid = (l + r) >> 1;
if(b <= mid) update(b , e , lson);
if(e > mid) update(b , e , rson);
sum[x] = sum[x << 1] + sum[x << 1 | 1];
}
int query(int b , int e , int l , int r , int x)
{
if(b <= l && r <= e) return sum[x];
pushdown(l , r , x);
int mid = (l + r) >> 1 , ans = 0;
if(b <= mid) ans += query(b , e , lson);
if(e > mid) ans += query(b , e , rson);
return ans;
}
void modify(int x)
{
while(bl[x]) update(pos[bl[x]] , pos[x] , 1 , n , 1) , x = fa[bl[x]];
update(1 , pos[x] , 1 , n , 1);
}
int solve(int x)
{
int ans = 0;
while(bl[x]) ans += query(pos[bl[x]] , pos[x] , 1 , n , 1) , x = fa[bl[x]];
return ans + query(1 , pos[x] , 1 , n , 1);
}
int main()
{
int m , i , l , r , x , h = 0;
scanf("%d%d" , &n , &m);
for(i = 1 ; i < n ; i ++ ) scanf("%d" , &fa[i]) , add(fa[i] , i);
dfs1(0) , dfs2(0 , 0);
for(i = 1 ; i <= m ; i ++ ) scanf("%d%d%d" , &l , &r , &x) , a[i] = data(l - 1 , x , -1 , i) , a[i + m] = data(r , x , 1 , i);
sort(a + 1 , a + 2 * m + 1);
for(i = 1 ; i <= 2 * m ; i ++ )
{
while(h <= a[i].p) modify(h++);
ans[a[i].id] += a[i].v * solve(a[i].z);
}
for(i = 1 ; i <= m ; i ++ ) printf("%d\n" , ans[i] % 201314);
return 0;
}