[poj1741][tree] (tree / centroid decomposition)

Time: 2023-03-09 14:40:01

Description

Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u,v) as the minimum distance between nodes u and v.

Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.

Write a program that counts how many valid pairs the given tree has.
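To pin down the definition, a brute-force counter can serve as a reference for small inputs. This is only a sketch for checking answers, not the intended solution: it runs one DFS per vertex, so it is O(n^2) per test case. Edge and countValidPairsBrute are hypothetical names, and vertices are assumed to be numbered 1..n.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical edge type: an undirected edge {u, v} of length w, vertices 1..n.
struct Edge {
    int u, v, w;
};

// Brute-force reference: one DFS per start vertex, O(n^2) in total.
// In a tree the unique simple path is the shortest one, so dist[] is exact.
long long countValidPairsBrute(int n, long long k, const std::vector<Edge>& edges) {
    assert(static_cast<int>(edges.size()) == n - 1);  // a tree has n - 1 edges
    std::vector<std::vector<std::pair<int, int>>> adj(n + 1);
    for (const Edge& e : edges) {
        adj[e.u].push_back({e.v, e.w});
        adj[e.v].push_back({e.u, e.w});
    }
    long long count = 0;
    for (int s = 1; s <= n; ++s) {
        std::vector<long long> dist(n + 1, -1);  // -1 marks "not reached yet"
        std::vector<int> stack{s};
        dist[s] = 0;
        while (!stack.empty()) {
            int x = stack.back();
            stack.pop_back();
            for (auto [y, w] : adj[x]) {
                if (dist[y] < 0) {
                    dist[y] = dist[x] + w;
                    stack.push_back(y);
                }
            }
        }
        for (int t = s + 1; t <= n; ++t)  // t > s: each unordered pair once
            if (dist[t] <= k) ++count;
    }
    return count;
}
```

For a small made-up tree with edges 1-2 (length 3), 1-3 (1), 1-4 (2), 3-5 (1) and k = 4, it reports 8 valid pairs; only (2,4) and (2,5), both at distance 5, fail.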

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.

The last test case is followed by two zeros.
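A minimal sketch of reading this format, assuming well-formed input and treating a first line with n = 0 as the terminator, as the solution code does. TestCase and readCases are hypothetical names; any std::istream works, for example std::cin or a std::istringstream in tests.

```cpp
#include <array>
#include <cassert>
#include <sstream>
#include <vector>

// Hypothetical container for one test case.
struct TestCase {
    int n = 0;
    long long k = 0;
    std::vector<std::array<int, 3>> edges;  // each edge is {u, v, l}
};

// Read test cases until a line with n == 0 (the "two zeros" terminator) or end of input.
std::vector<TestCase> readCases(std::istream& in) {
    std::vector<TestCase> cases;
    TestCase tc;
    while (in >> tc.n >> tc.k && tc.n != 0) {
        tc.edges.clear();
        for (int i = 0; i < tc.n - 1; ++i) {
            int u = 0, v = 0, l = 0;
            in >> u >> v >> l;
            assert(in && "input ended in the middle of an edge list");
            tc.edges.push_back({u, v, l});
        }
        cases.push_back(tc);
    }
    return cases;
}
```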

Output

For each test case output the answer on a single line.

Sample Input


Sample Output


Source

Solution

1. Centroid decomposition + sorting

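First find the centroid and compute the answer around it. For each centroid, count every pair of vertices in its current component whose path through the centroid (distance from one vertex to the centroid plus distance from the centroid to the other) is at most k; call this count ans1.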
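The counting step for one centroid can be sketched like this. Assumptions: the tree is stored as adjacency lists of (neighbor, length) pairs with vertices 1..n, and a removed flag marks centroids already processed (the role vis[] plays in the code); all function names are hypothetical. The distances of every vertex in the component to the centroid (the centroid itself at 0) are sorted, then two pointers count pairs with sum <= k: if dist[l] + dist[r] <= k, dist[l] pairs with every element from l+1 to r, so add r - l and advance l; otherwise dist[r] is too large for any remaining left end, so decrease r.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// adj[x] holds (neighbor, edge length) pairs; vertices are 1..n, 0 is a sentinel.
using Adj = std::vector<std::vector<std::pair<int, int>>>;

// Count pairs i < j with dist[i] + dist[j] <= k (sorts its own copy).
long long countPairsAtMost(std::vector<long long> dist, long long k) {
    std::sort(dist.begin(), dist.end());
    long long count = 0;
    int l = 0, r = static_cast<int>(dist.size()) - 1;
    while (l < r) {
        if (dist[l] + dist[r] <= k) {
            count += r - l;  // dist[l] fits with every dist[l+1..r]
            ++l;
        } else {
            --r;             // dist[r] is too large for every remaining left end
        }
    }
    return count;
}

// Push the distance to the centroid of every vertex reachable from x
// (x itself has distance d) without entering a removed vertex.
void collectDistances(const Adj& adj, const std::vector<bool>& removed,
                      int x, int parent, long long d, std::vector<long long>& out) {
    out.push_back(d);
    for (auto [y, w] : adj[x])
        if (y != parent && !removed[y])
            collectDistances(adj, removed, y, x, d + w, out);
}

// ans1 for centroid c: pairs in c's component whose distances to c sum to <= k.
long long ans1(const Adj& adj, const std::vector<bool>& removed, int c, long long k) {
    assert(!removed[c]);
    std::vector<long long> dist;
    collectDistances(adj, removed, c, 0, 0, dist);
    return countPairsAtMost(dist, k);
}
```

With edges 1-2 (3), 1-3 (1), 1-4 (2), 3-5 (1), k = 4 and centroid 1, the sorted distances are 0, 1, 2, 2, 3 and ans1 comes out as 8. That includes the pair (3,5), whose distances 1 and 2 add up to 3 even though its real path is the single edge 3-5.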

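However, these pairs include the following bad case:

That is, let p be a child of the centroid of any subtree produced by the decomposition. Two vertices in p's subtree may both reach the centroid through the edge between p and the centroid; their "path through the centroid" then uses that edge twice, is not their shortest path, and must not be counted at this centroid.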

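To remove these cases, at each centroid we compute, for every child p, the number of pairs inside p's subtree whose distances to the centroid (each including the edge from p to the centroid) sum to at most k; call this ans2. Exactly these pairs were wrongly included in ans1.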
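A sketch of the correction under the same assumptions (adjacency lists, vertices 1..n, hypothetical helper names). For a child y joined to the centroid c by an edge of length w, the DFS inside y's subtree starts at distance w, so every distance is still measured to c; counting pairs with sum <= k then yields exactly the same-subtree pairs that ans1 included.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

using Adj = std::vector<std::vector<std::pair<int, int>>>;  // adj[x] = {(y, w), ...}

// Two-pointer count of pairs i < j with dist[i] + dist[j] <= k.
long long countPairsAtMost(std::vector<long long> dist, long long k) {
    std::sort(dist.begin(), dist.end());
    long long count = 0;
    for (int l = 0, r = static_cast<int>(dist.size()) - 1; l < r;) {
        if (dist[l] + dist[r] <= k) { count += r - l; ++l; }
        else --r;
    }
    return count;
}

// Distances (starting at d for x) of all vertices reachable from x, skipping removed ones.
void collectDistances(const Adj& adj, const std::vector<bool>& removed,
                      int x, int parent, long long d, std::vector<long long>& out) {
    out.push_back(d);
    for (auto [y, w] : adj[x])
        if (y != parent && !removed[y])
            collectDistances(adj, removed, y, x, d + w, out);
}

// ans2 for child y of centroid c (edge length w): distances start at w, so they are
// still measured to c; the pairs counted are exactly the same-subtree pairs in ans1.
long long ans2(const Adj& adj, const std::vector<bool>& removed,
               int c, int y, int w, long long k) {
    std::vector<long long> dist;
    collectDistances(adj, removed, y, c, w, dist);  // parent = c: never walk back into c
    return countPairsAtMost(dist, k);
}

// Pairs whose real path passes through c: ans1 - sum(ans2) over c's live children.
long long pairsThroughCentroid(const Adj& adj, const std::vector<bool>& removed,
                               int c, long long k) {
    assert(!removed[c]);
    std::vector<long long> all;
    collectDistances(adj, removed, c, 0, 0, all);
    long long result = countPairsAtMost(all, k);  // ans1
    for (auto [y, w] : adj[c])
        if (!removed[y]) result -= ans2(adj, removed, c, y, w, k);
    return result;
}
```

For edges 1-2 (3), 1-3 (1), 1-4 (2), 3-5 (1), k = 4 and centroid 1, ans2 is 1 for child 3 (the pair (3,5)) and 0 for children 2 and 4, so 8 - 1 = 7 pairs really pass through vertex 1.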

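ans1 - sum(ans2) counts the valid pairs whose path actually passes through this centroid. Every pair is counted at the first centroid chosen on its path, after which its two vertices lie in different components, so summing this value over all centroids of the recursion gives the final answer.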
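Putting the steps together, here is a self-contained sketch of the whole recursion with std::vector adjacency lists instead of the fixed arrays of the submitted code. CentroidCounter is a hypothetical name, and the edges are assumed to form a tree on vertices 1..n. It locates the centroid by walking from the component's root toward any child whose subtree holds more than half of the component, a standard alternative to the minimum-of-maximum scan in getroot; then it adds ans1, subtracts ans2 for each child, and recurses into each child's component.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical class name; vertices are 1..n and the edges are assumed to form a tree.
class CentroidCounter {
public:
    CentroidCounter(int n, long long k)
        : adj_(n + 1), removed_(n + 1, false), sz_(n + 1, 0), k_(k) {
        assert(n >= 1);
    }

    void addEdge(int u, int v, int w) {
        adj_[u].push_back({v, w});
        adj_[v].push_back({u, w});
    }

    // Number of unordered pairs (u, v) with dist(u, v) <= k. Call once.
    long long count() { return solve(1); }

private:
    std::vector<std::vector<std::pair<int, int>>> adj_;  // (neighbor, length)
    std::vector<bool> removed_;                          // centroids already used
    std::vector<int> sz_;                                // subtree sizes of the current pass
    long long k_;

    int computeSizes(int x, int parent) {
        sz_[x] = 1;
        for (auto [y, w] : adj_[x])
            if (y != parent && !removed_[y]) sz_[x] += computeSizes(y, x);
        return sz_[x];
    }

    // Move toward a child holding more than half of the component until none does.
    int findCentroid(int x, int parent, int total) {
        for (auto [y, w] : adj_[x])
            if (y != parent && !removed_[y] && 2 * sz_[y] > total)
                return findCentroid(y, x, total);
        return x;
    }

    // Distances (x itself at d) of every vertex reachable from x, skipping removed ones.
    void collect(int x, int parent, long long d, std::vector<long long>& out) {
        out.push_back(d);
        for (auto [y, w] : adj_[x])
            if (y != parent && !removed_[y]) collect(y, x, d + w, out);
    }

    // Two-pointer count of pairs whose distances sum to <= k.
    long long countPairs(std::vector<long long> dist) const {
        std::sort(dist.begin(), dist.end());
        long long pairs = 0;
        for (int l = 0, r = static_cast<int>(dist.size()) - 1; l < r;) {
            if (dist[l] + dist[r] <= k_) { pairs += r - l; ++l; }
            else --r;
        }
        return pairs;
    }

    // Count valid pairs inside the component containing start.
    long long solve(int start) {
        int total = computeSizes(start, 0);
        int c = findCentroid(start, 0, total);
        std::vector<long long> all;
        collect(c, 0, 0, all);
        long long result = countPairs(all);  // ans1
        removed_[c] = true;
        for (auto [y, w] : adj_[c]) {
            if (removed_[y]) continue;
            std::vector<long long> sub;
            collect(y, c, w, sub);
            result -= countPairs(sub);       // ans2 for child y
            result += solve(y);              // recurse into y's component
        }
        return result;
    }
};
```

ans2 for a child is collected before recursing into it, because the recursion marks more vertices as removed. Every level of the recursion sorts O(n) distances in total and the depth is O(log n), which gives roughly O(n log^2 n) per test case. On the tree with edges 1-2 (3), 1-3 (1), 1-4 (2), 3-5 (1) and k = 4, count() returns 8.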

Submission record (Run ID 16456848): ksq2013, problem 1741, Accepted, 760K memory, 172MS, C++, 1547B code, submitted 2017-01-07 13:50:58
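#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 10010
#define inf (~0U>>1)
using namespace std;
// Forward-star adjacency list: fst[x] is the first edge out of x, e[j].nxt the next one.
int fst[N],ecnt,ans;
struct edge{
    int v,w,nxt;
}e[N<<1];                                   // 2*(n-1) directed edges for an undirected tree
inline void link(int x,int y,int w){
    e[++ecnt].v=y;
    e[ecnt].w=w;
    e[ecnt].nxt=fst[x];
    fst[x]=ecnt;
}
bool vis[N];                                // vis[x]: x has already been used as a centroid
// sz[] was called size[] in the submission; renamed because it clashes with std::size (C++17).
int n,m,root,f[N],sz[N],d[N],deep[N],sum,top;
// Centroid search in the current component of sum vertices:
// f[x] = size of the largest part left after deleting x; keep the x minimizing it.
void getroot(int x,int fa){
    f[x]=0;
    sz[x]=1;
    for(int j=fst[x];j;j=e[j].nxt)
        if(e[j].v!=fa&&!vis[e[j].v]){
            getroot(e[j].v,x);
            sz[x]+=sz[e[j].v];
            f[x]=max(f[x],sz[e[j].v]);
        }
    f[x]=max(f[x],sum-sz[x]);
    if(f[x]<=f[root])root=x;
}
// Store d[] (distance to the current centroid) of every vertex reachable from x.
void getdeep(int x,int fa){
    deep[++top]=d[x];
    for(int j=fst[x];j;j=e[j].nxt)
        if(e[j].v!=fa&&!vis[e[j].v]){
            d[e[j].v]=d[x]+e[j].w;
            getdeep(e[j].v,x);
        }
}
// Count pairs reachable from x whose distances sum to <= m (= k), starting with d[x]=v.
int cal(int x,int v){
    d[x]=v;top=0;
    getdeep(x,0);
    sort(deep+1,deep+1+top);
    int t=0;
    for(int l=1,r=top;l<r;)
        if(deep[l]+deep[r]<=m)
            t+=r-l,l++;                     // deep[l] pairs with every deep[l+1..r]
        else r--;
    return t;
}
void solve(int x){
    vis[x]=1;
    ans+=cal(x,0);                          // ans1
    for(int j=fst[x];j;j=e[j].nxt)
        if(!vis[e[j].v]){
            ans-=cal(e[j].v,e[j].w);        // ans2 for child e[j].v
            root=0;sum=sz[e[j].v];          // size from the previous getroot pass; may be inexact,
            getroot(e[j].v,0);              // but it only steers which vertex becomes the centroid
            solve(root);                    // recurse into the child's component
        }
}
int main(){
    while(scanf("%d%d",&n,&m)==2&&n){       // ==2 also stops cleanly at end of input
        ans=ecnt=0;
        memset(fst,0,sizeof(fst));
        memset(vis,0,sizeof(vis));
        for(int i=1;i<n;i++){
            int x,y,w;
            scanf("%d%d%d",&x,&y,&w);
            link(x,y,w);link(y,x,w);
        }
        root=0;f[0]=inf;sum=n;              // f[0]=inf so any real vertex replaces root 0
        getroot(1,0);
        solve(root);
        printf("%d\n",ans);
    }
    return 0;
}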