BZOJ 3653 谈笑风生

时间:2022-12-25 10:38:43

3653: 谈笑风生

Time Limit: 20 Sec  Memory Limit: 512 MB
Submit: 966  Solved: 391
[Submit][Status][Discuss]

Description

设T 为一棵有根树,我们做如下的定义:
? 设a和b为T 中的两个不同节点。如果a是b的祖先,那么称“a比b不知道
高明到哪里去了”。
? 设a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定
常数x,那么称“a 与b 谈笑风生”。
给定一棵n个节点的有根树T,节点的编号为1 到 n,根节点为1号节点。你需
要回答q 个询问,询问给定两个整数p和k,问有多少个有序三元组(a;b;c)满足:
1. a、b和 c为 T 中三个不同的点,且 a为p 号节点;
2. a和b 都比 c不知道高明到哪里去了;
3. a和b 谈笑风生。这里谈笑风生中的常数为给定的 k。

Input

第一行含有两个正整数n和q,分别代表有根树的点数与询问的个数。
接下来n - 1行,每行描述一条树上的边。每行含有两个整数u和v,代表在节点u和v之间有一条边。
接下来q行,每行描述一个操作。第i行含有两个整数,分别表示第i个询问的p和k。
1<=P<=N
1<=K<=N
N<=300000
Q<=300000
 

Output

输出 q 行,每行对应一个询问,代表询问的答案。

Sample Input

5 3
1 2
1 3
2 4
4 5
2 2
4 1
2 3

Sample Output

3
1
3

HINT

 

Hint:边要加双向

 

Source

线段树以深度为关键字维护size的和

x,y的答案  = size[x] * min(deep[x], y) + dfs序在l[x] + 1到r[x]之间且深度在deep[x] + 1到deep[x] + k之间的size和

主席树写掉

#include<queue>
#include<cstdio>
#include<cstring>
#include<iostream>
#define inf 1000000000
#define pa pair<int,int>
#define ll long long 
using namespace std;
int read()
{
	int x=0;char ch=getchar();
	while(ch<'0'||ch>'9')ch=getchar();
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x;
}
int n,Q,cnt,ind,sz;
int last[300005],q[300005];
int deep[300005],size[300005];
int l[300005],r[300005],root[300005];
int ls[6000005],rs[6000005];
ll sum[6000005];
struct edge{
	int to,next;
}e[600005];
void insert(int u,int v)
{
	e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
	e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}
void dfs(int x,int fa)
{
	l[x]=++ind;
	q[ind]=x;
	for(int i=last[x];i;i=e[i].next)
		if(e[i].to!=fa)
		{
			deep[e[i].to]=deep[x]+1;
			dfs(e[i].to,x);
			size[x]+=size[e[i].to]+1;
		}
	r[x]=ind;
}
void build(int &x,int y,int l,int r,int pos,int val)
{
	x=++sz;
	sum[x]=sum[y]+val;
	if(l==r)return;
	ls[x]=ls[y];rs[x]=rs[y];
	int mid=(l+r)>>1;
	if(pos<=mid)build(ls[x],ls[y],l,mid,pos,val);
	else build(rs[x],rs[y],mid+1,r,pos,val);
}
ll query(int k,int l,int r,int x,int y)
{
	if(y>r)y=r;
	if(!k)return 0;
	if(l==x&&y==r)return sum[k];
	int mid=(l+r)>>1;
	if(y<=mid)return query(ls[k],l,mid,x,y);
	else if(x>mid)return query(rs[k],mid+1,r,x,y);
	else return query(ls[k],l,mid,x,mid)+query(rs[k],mid+1,r,mid+1,y);
}
int main()
{
	n=read();Q=read();
	for(int i=1;i<n;i++)
	{
		int u=read(),v=read();
		insert(u,v);
	}
	dfs(1,0);
	int mx=0;
	for(int i=1;i<=n;i++)mx=max(deep[i],mx);
	for(int i=1;i<=n;i++)
		build(root[i],root[i-1],0,mx,deep[q[i]],size[q[i]]);
    while(Q--)
	{
		int P=read(),K=read();
		ll ans=0;
		ans+=(ll)size[P]*min(deep[P],K);
		ans+=query(root[r[P]],0,mx,deep[P]+1,deep[P]+K);
		ans-=query(root[l[P]-1],0,mx,deep[P]+1,deep[P]+K);
		printf("%lld\n",ans);
	}
	return 0;
}