[湖南集训] 谈笑风生 (主席树)

时间:2021-01-23 10:38:27

[湖南集训] 谈笑风生

题目描述

设 T 为一棵有根树,我们做如下的定义:

• 设 a 和 b 为 T 中的两个不同节点。如果 a 是 b 的祖先,那么称“a 比 b 不知道高明到哪里去了”。

• 设 a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定常数 x,那么称“a 与 b 谈笑风生”。

给定一棵 n 个节点的有根树 T,节点的编号为 1 ∼ n,根节点为 1 号节点。你需要回答 q 个询问,询问给定两个整数 p 和 k,问有多少个有序三元组 (a; b; c) 满足:

  1. a、 b 和 c 为 T 中三个不同的点,且 a 为 p 号节点;

  2. a 和 b 都比 c 不知道高明到哪里去了;

  3. a 和 b 谈笑风生。这里谈笑风生中的常数为给定的 k。

输入输出格式

输入格式:

输入文件的第一行含有两个正整数 n 和 q,分别代表有根树的点数与询问的个数。

接下来 n − 1 行,每行描述一条树上的边。每行含有两个整数 u 和 v,代表在节点 u 和 v 之间有一条边。

接下来 q 行,每行描述一个操作。第 i 行含有两个整数,分别表示第 i 个询问的 p 和 k。

输出格式:

输出 q 行,每行对应一个询问,代表询问的答案。

输入输出样例

输入样例#1:

5 3
1 2
1 3
2 4
4 5
2 2
4 1
2 3

输出样例#1:

3
1
3

说明

样例中的树如下图所示:

[湖南集训] 谈笑风生 (主席树)

对于第一个和第三个询问,合法的三元组有 (2,1,4)、 (2,1,5) 和 (2,4,5)。

对于第二个询问,合法的三元组只有 (4,2,5)。

所有测试点的数据规模如下:

[湖南集训] 谈笑风生 (主席树)

对于全部测试数据的所有询问, 1 ≤ p ≤ n, 1 ≤ k ≤ n.

Solution

tip:这道题可以花式做,线段树合并....(然而蒟蒻我并不会)

其实题目的信息可以概括为这么几个条件

1. a和b都是c的祖先节点
2. a和b不是同一个节点
3. a和b在树中的深度之差的绝对值不超过k

那么其实我们可以分类讨论一下,对于\(b\)\(a\)的上面的情况,\(c\)只能是\(a\)的子树中的节点,又因为不能为\(a\),所以有\(size[a]-1\)种可能,而\(b\)既然在a的上方,那就只有\(min(k,dep[a]-1)\)种取值了

所以很明显,这一部分的\(ans=min(k,dep[a]-1)\times (size[a]-1)\)

那么如果\(b\)\(a\)的下方,则\(c\)只能为\(b\)的子树节点,所以对于一个节点,它对答案的贡献显然为\(size[]-1\),我们把这个答案用前缀和记录下来,用一种类似差分的方法查询就行了

怎么搞?线段树啊,主席树啊随便你乱搞...反正我写的是主席树

主席树的话,因为我们要求贡献,那么肯定是要把\(size[]\)当权值插入了,那就只能把\(dfs\)序当下标建树了
查询的时候也像线段树那样查个差值就好了

然后两个部分的答案加起来就好了

注意1:如果当前点的深度已经是最大深度了,就代表不可能有节点c了,输出0
注意2:开 long long

Code

#include<bits/stdc++.h>
#define in(i) (i=read())
#define il extern inline
#define rg register
#define mid ((l+r)>>1)
#define Min(a,b) ((a)<(b)?(a):(b))
#define Max(a,b) ((a)>(b)?(a):(b))
#define lol long long
using namespace std;

const lol N=3e5+10;

lol read() {
    lol ans=0, f=1; char i=getchar();
    while (i<'0' || i>'9') {if(i=='-') f=-1; i=getchar();}
    while (i>='0' && i<='9') ans=(ans<<1)+(ans<<3)+(i^48), i=getchar();
    return ans*f;
}

lol n,m,cur,maxn,cnt,tot;
lol head[N],nex[N<<1],to[N<<1];
lol rt[N],dep[N],size[N],dfn[N];

struct Chair_Tree {
    lol l,r,v;
}t[N<<5];

void add(lol a,lol b) {
    to[++cur]=b,nex[cur]=head[a],head[a]=cur;
    to[++cur]=a,nex[cur]=head[b],head[b]=cur;
}

void insert(lol &u,lol l,lol r,lol pre,lol pos,lol v) {
    t[u=++tot]=t[pre], t[u].v+=v;
    if(l==r) return;
    if(pos<=mid) insert(t[u].l,l,mid,t[pre].l,pos,v);
    else insert(t[u].r,mid+1,r,t[pre].r,pos,v);
}

lol query(lol u,lol v,lol l,lol r,lol left,lol right,lol ans=0) {
    //cout<<l<<" "<<r<<" "<<mid<<endl;
    if(left<=l && r<=right) return t[v].v-t[u].v;
    if(left<=mid) ans+=query(t[u].l,t[v].l,l,mid,left,right);
    if(mid<right) ans+=query(t[u].r,t[v].r,mid+1,r,left,right);
    return ans;
}

void init(lol u,lol fa) {
    size[u]=1;
    for (lol i=head[u];i;i=nex[i]) {
        if(to[i]==fa) continue;
        dep[to[i]]=dep[u]+1;
        maxn=Max(maxn,dep[to[i]]);
        init(to[i],u);
        size[u]+=size[to[i]];
    }
}

void dfs(lol u,lol fa) {
    dfn[u]+=++cnt;
    insert(rt[cnt],1,maxn,rt[cnt-1],dep[u],size[u]-1);
    for (lol i=head[u];i;i=nex[i]) {
        if(to[i]==fa) continue;
        dfs(to[i],u);
    }
}

int main()
{
    in(n), in(m);
    for (lol i=1,a,b;i<n;i++)
        in(a), in(b), add(a,b);
    dep[1]=1, init(1,0); dfs(1,0);
    for (lol i=1,p,k,ans;i<=m;i++) {
        in(p), in(k);
        ans=(size[p]-1)*Min(k,dep[p]-1);
        ans+=query(rt[dfn[p]-1],rt[dfn[p]+size[p]-1],1,maxn,dep[p]+1,Min(dep[p]+k,maxn));
        if(dep[p]==maxn) ans=0;
        printf("%lld\n",ans);
    }
}