题目描述
给定一棵N(1<=N<=100000)个结点的带权树,每条边都有一个权值(为正整数,小于等于1001)。定义dis(u,v)为u,v两点间的最短路径长度,路径的长度定义为路径上所有边的权和。再给定一个K(1<=K<=10^9),如果对于不同的两个结点u,v,如果满足dist(u,v)<=K,则称(u,v)为合法点对。求合法点对个数。
题目大意
求树中距离小于k的点对个数
数据范围
对于50%的数据,n<=1000,k<=1000;
对于100%的数据,n<=100000,k<=10^9;
样例输入
5 4
1 2 3
1 3 1
1 4 2
3 5 1
样例输出
8
解题思路
不写了233
代码
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <ctime>
#define Maxn 100005
using namespace std;
inline int Getint(){int x=0,f=1;char ch=getchar();while('0'>ch||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while('0'<=ch&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
int fa[Maxn],h[Maxn],size[Maxn],dep[Maxn];
bool vis[Maxn];
int n,k,cnt=0,L,r,Min;
struct node{int to,next,v;}e[Maxn*2];
void AddEdge(int x,int y,int v){e[++cnt]=(node){y,h[x],v};h[x]=cnt;}
void Init(){
n=Getint(),k=Getint();
for(int i=1;i<n;i++){
int x,y,v;
x=Getint(),y=Getint(),v=Getint();
fa[y]=x;
AddEdge(x,y,v);
AddEdge(y,x,v);
}
memset(vis,0,sizeof(vis));
}
int dfssize(int u,int pre){
size[u]=1;
for(int p=h[u];p;p=e[p].next){
int y=e[p].to;
if(vis[y]||y==pre)continue;
size[u]+=dfssize(y,u);
}
return size[u];
}
void Getroot(int u,int pre,int tot,int &root){
int Max=tot-size[u];
for(int p=h[u];p;p=e[p].next){
int y=e[p].to;
if(vis[y]||y==pre)continue;
Getroot(y,u,tot,root);
Max=max(Max,size[y]);
}
if(Max<Min){
Min=Max;
root=u;
}
}
void Getlen(int u,int pre,int d){
dep[r++]=d;
for(int p=h[u];p;p=e[p].next){
int y=e[p].to;
if(vis[y]||y==pre)continue;
Getlen(y,u,d+e[p].v);
}
}
int Calc(int L,int r){
sort(dep+L,dep+r);
int ret=0,Pos=r-1;
for(int i=L;i<r;i++){
if(dep[i]>k)break;
while(Pos>=L&&dep[i]+dep[Pos]>k)Pos--;
ret+=Pos-L+1;
if(Pos>i)ret--;
}
return ret/2;
}
int Solve(int u){
int tot=dfssize(u,0),ret=0,root;
Min=0x7fffffff;
Getroot(u,0,tot,root);
vis[root]=true;
for(int p=h[root];p;p=e[p].next){
int y=e[p].to;
if(vis[y])continue;
ret+=Solve(y);
}
L=r=0;
for(int p=h[root];p;p=e[p].next){
int y=e[p].to;
if(vis[y])continue;
Getlen(y,root,e[p].v);
ret-=Calc(L,r);
L=r;
}
ret+=Calc(0,r);
for(int i=0;i<r;i++)
if(dep[i]<=k)ret++;
else break;
vis[root]=false;
return ret;
}
int main(){
Init();
cout<<Solve(1)<<"\n";
}