[湖南集训] 谈笑风生
题目描述
设 T 为一棵有根树,我们做如下的定义:
• 设 a 和 b 为 T 中的两个不同节点。如果 a 是 b 的祖先,那么称“a 比 b 不知道高明到哪里去了”。
• 设 a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定常数 x,那么称“a 与 b 谈笑风生”。
给定一棵 n 个节点的有根树 T,节点的编号为 1 ∼ n,根节点为 1 号节点。你需要回答 q 个询问,询问给定两个整数 p 和 k,问有多少个有序三元组 (a; b; c) 满足:
a、 b 和 c 为 T 中三个不同的点,且 a 为 p 号节点;
a 和 b 都比 c 不知道高明到哪里去了;
a 和 b 谈笑风生。这里谈笑风生中的常数为给定的 k。
输入输出格式
输入格式:
输入文件的第一行含有两个正整数 n 和 q,分别代表有根树的点数与询问的个数。
接下来 n − 1 行,每行描述一条树上的边。每行含有两个整数 u 和 v,代表在节点 u 和 v 之间有一条边。
接下来 q 行,每行描述一个操作。第 i 行含有两个整数,分别表示第 i 个询问的 p 和 k。
输出格式:
输出 q 行,每行对应一个询问,代表询问的答案。
输入输出样例
输入样例#1:
5 3
1 2
1 3
2 4
4 5
2 2
4 1
2 3
输出样例#1:
3
1
3
说明
样例中的树如下图所示:
对于第一个和第三个询问,合法的三元组有 (2,1,4)、 (2,1,5) 和 (2,4,5)。
对于第二个询问,合法的三元组只有 (4,2,5)。
所有测试点的数据规模如下:
对于全部测试数据的所有询问, 1 ≤ p ≤ n, 1 ≤ k ≤ n.
Solution
tip:这道题可以花式做,线段树合并....(然而蒟蒻我并不会)
其实题目的信息可以概括为这么几个条件
1. a和b都是c的祖先节点
2. a和b不是同一个节点
3. a和b在树中的深度之差的绝对值不超过k
那么其实我们可以分类讨论一下,对于\(b\)在\(a\)的上面的情况,\(c\)只能是\(a\)的子树中的节点,又因为不能为\(a\),所以有\(size[a]-1\)种可能,而\(b\)既然在a的上方,那就只有\(min(k,dep[a]-1)\)种取值了
所以很明显,这一部分的\(ans=min(k,dep[a]-1)\times (size[a]-1)\)
那么如果\(b\)在\(a\)的下方,则\(c\)只能为\(b\)的子树节点,所以对于一个节点,它对答案的贡献显然为\(size[]-1\),我们把这个答案用前缀和记录下来,用一种类似差分的方法查询就行了
怎么搞?线段树啊,主席树啊随便你乱搞...反正我写的是主席树
主席树的话,因为我们要求贡献,那么肯定是要把\(size[]\)当权值插入了,那就只能把\(dfs\)序当下标建树了
查询的时候也像线段树那样查个差值就好了
然后两个部分的答案加起来就好了
注意1:如果当前点的深度已经是最大深度了,就代表不可能有节点c了,输出0
注意2:开 long long
Code
#include<bits/stdc++.h>
#define in(i) (i=read())
#define il extern inline
#define rg register
#define mid ((l+r)>>1)
#define Min(a,b) ((a)<(b)?(a):(b))
#define Max(a,b) ((a)>(b)?(a):(b))
#define lol long long
using namespace std;
const lol N=3e5+10;
lol read() {
lol ans=0, f=1; char i=getchar();
while (i<'0' || i>'9') {if(i=='-') f=-1; i=getchar();}
while (i>='0' && i<='9') ans=(ans<<1)+(ans<<3)+(i^48), i=getchar();
return ans*f;
}
lol n,m,cur,maxn,cnt,tot;
lol head[N],nex[N<<1],to[N<<1];
lol rt[N],dep[N],size[N],dfn[N];
struct Chair_Tree {
lol l,r,v;
}t[N<<5];
void add(lol a,lol b) {
to[++cur]=b,nex[cur]=head[a],head[a]=cur;
to[++cur]=a,nex[cur]=head[b],head[b]=cur;
}
void insert(lol &u,lol l,lol r,lol pre,lol pos,lol v) {
t[u=++tot]=t[pre], t[u].v+=v;
if(l==r) return;
if(pos<=mid) insert(t[u].l,l,mid,t[pre].l,pos,v);
else insert(t[u].r,mid+1,r,t[pre].r,pos,v);
}
lol query(lol u,lol v,lol l,lol r,lol left,lol right,lol ans=0) {
//cout<<l<<" "<<r<<" "<<mid<<endl;
if(left<=l && r<=right) return t[v].v-t[u].v;
if(left<=mid) ans+=query(t[u].l,t[v].l,l,mid,left,right);
if(mid<right) ans+=query(t[u].r,t[v].r,mid+1,r,left,right);
return ans;
}
void init(lol u,lol fa) {
size[u]=1;
for (lol i=head[u];i;i=nex[i]) {
if(to[i]==fa) continue;
dep[to[i]]=dep[u]+1;
maxn=Max(maxn,dep[to[i]]);
init(to[i],u);
size[u]+=size[to[i]];
}
}
void dfs(lol u,lol fa) {
dfn[u]+=++cnt;
insert(rt[cnt],1,maxn,rt[cnt-1],dep[u],size[u]-1);
for (lol i=head[u];i;i=nex[i]) {
if(to[i]==fa) continue;
dfs(to[i],u);
}
}
int main()
{
in(n), in(m);
for (lol i=1,a,b;i<n;i++)
in(a), in(b), add(a,b);
dep[1]=1, init(1,0); dfs(1,0);
for (lol i=1,p,k,ans;i<=m;i++) {
in(p), in(k);
ans=(size[p]-1)*Min(k,dep[p]-1);
ans+=query(rt[dfn[p]-1],rt[dfn[p]+size[p]-1],1,maxn,dep[p]+1,Min(dep[p]+k,maxn));
if(dep[p]==maxn) ans=0;
printf("%lld\n",ans);
}
}