Time Limit: 1000MS | Memory Limit: 30000K
Total Submissions: 18550 | Accepted: 6069
Description
Given a tree with n vertices, each edge has a length (a positive integer less than 1001). Define dist(u,v) = the minimum distance between nodes u and v.
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.
Write a program that counts how many pairs are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k (n<=10000). The following n-1 lines each contain three integers u, v, l, which means there is an edge of length l between nodes u and v. The last test case is followed by two zeros.
Output
For each test case, output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
Source
LouTiancheng@POJ
树分治之点分治算法 (Tree divide and conquer: the centroid decomposition algorithm)
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
using namespace std;

// Capacity of the per-vertex arrays; the statement guarantees n <= 10000 and
// vertices are numbered from 1, so index 0 is free to act as a sentinel.
#define maxn 10005

// One adjacency-list entry: (neighbour vertex, edge length).
typedef pair<int,int> pii;
#define mk(a,b) make_pair(a,b)

// Inputs of the current test case: number of vertices and the distance bound.
int n;
int k;

// g[u] lists the edges leaving vertex u.
vector<pii> g[maxn<<2];

// Vertex most recently selected by getroot(); 0 means "none selected yet".
int rt;
// s[u]: subtree size of u as computed by the latest getroot()/dfs() pass.
int s[maxn];
// f[u]: balance score of u computed by getroot(); f[0] is the sentinel score.
int f[maxn];
// vis[u]: non-zero once u has been processed by solve() and removed.
int vis[maxn];
// d[u]: distance label assigned to u by the latest calc()/dfs() pass.
int d[maxn];
// Running answer for the current test case.
int ans;
// Number of vertices in the component getroot() is currently searching.
int size;
// q[0] is the number of collected distances, q[1..q[0]] are the distances.
int q[maxn];
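/********
* s[i] -> size of the subtree rooted at i (node count, including i itself)
* f[i] -> the largest s[v] among the children v of i
*         (starting value; getroot() then combines it with the rest of the component)
* Property: the vertex with the smallest f[i] is the centroid of the tree
*********/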
// getroot() => find the centroid of the not-yet-removed component containing u.
//
// The caller must first store the component's vertex count in `size` and set
// rt = 0, f[0] = size, so that f[0] is a score no real vertex can exceed.
//   u  - vertex being visited
//   fa - vertex we arrived from (0 for the starting vertex); it is not re-entered
// Afterwards, for every visited vertex x, s[x] is the size of x's subtree
// (rooted at the starting vertex) and f[x] is the size of the largest piece
// that remains when x is deleted. rt is the first vertex in post-order with
// the smallest f[x].
void getroot(int u,int fa)
{
    s[u] = 1;
    f[u] = 0;
    const int degree = static_cast<int>(g[u].size());
    for(int i=0;i<degree;i++)
    {
        const int v = g[u][i].first;
        if(v==fa||vis[v])continue; // skip the parent and vertices already removed
        getroot(v,u);
        s[u]+=s[v];
        f[u] = std::max(f[u],s[v]);
    }
    // Deleting u also leaves the piece on the parent's side, which contains
    // every vertex outside u's subtree: size - s[u]. Using f[u] here instead
    // of s[u] would score vertices incorrectly and could select a vertex that
    // is not a centroid, losing the balance guarantee of the decomposition.
    f[u] = std::max(f[u],size-s[u]);
    if(f[u]<f[rt])rt = u;
}
// dfs() => collect distances and subtree sizes below u.
//
// Visits every vertex reachable from u without passing through `fa` or a
// removed (vis) vertex. Each visited vertex has its d[] value appended to the
// list kept in q (q[0] is the element count) and its s[] value recomputed.
//   u  - current vertex; d[u] must already hold its distance label
//   fa - vertex we came from, which is not revisited
void dfs(int u,int fa)
{
    s[u] = 1;

    // Append d[u] to the 1-based list stored behind the counter q[0].
    const int slot = ++q[0];
    q[slot] = d[u];

    const int degree = static_cast<int>(g[u].size());
    for(int e=0;e<degree;++e)
    {
        const int child = g[u][e].first;
        if(child!=fa&&!vis[child])
        {
            d[child] = d[u] + g[u][e].second;
            dfs(child,u);
            s[u] += s[child];
        }
    }
}
// calc() => count pairs of vertices whose distance labels sum to at most k.
//
// Labels the active component reachable from u with d[x] = deep + (distance
// from u to x), gathers those labels through dfs(), sorts them and counts the
// unordered pairs (i, j), i != j, with d[i] + d[j] <= k.
//   u    - vertex the traversal starts from
//   deep - label given to u itself
// Returns the number of such pairs. As a side effect s[] is refreshed by dfs().
int calc(int u,int deep)
{
    d[u] = deep;
    q[0] = 0;
    dfs(u,0);

    int* const first = q + 1;
    std::sort(first,first+q[0]);

    // Two pointers over the sorted labels: when the smallest remaining label
    // pairs with the largest one, it pairs with everything in between as well.
    int pairs = 0;
    int lo = 1;
    int hi = q[0];
    while(lo<hi)
    {
        if(q[lo]+q[hi]>k)
        {
            --hi;
            continue;
        }
        pairs += hi-lo;
        ++lo;
    }
    return pairs;
}
// solve() => process the component whose chosen root is u, then recurse.
//
// Adds to `ans` the pairs of the component whose path passes through u, marks
// u as removed, and repeats the procedure on each remaining piece using the
// vertex that getroot() selects for it.
void solve(int u)
{
    vis[u] = 1;

    // Every pair (i, j) in the component with d[i] + d[j] <= k, measured from u.
    ans += calc(u,0);

    const int degree = static_cast<int>(g[u].size());
    for(int e=0;e<degree;++e)
    {
        const int child = g[u][e].first;
        if(!vis[child])
        {
            // Pairs lying inside the same child piece were counted above even
            // though their path does not go through u; take them back out.
            ans -= calc(child,g[u][e].second);

            // calc() has just refreshed s[child] to the size of that piece.
            size = s[child];
            f[0] = size;
            rt = 0;
            getroot(child,0);
            solve(rt);
        }
    }
}
// Reads test cases of the form "n k" followed by n-1 edges "u v l" and prints
// the number of valid pairs for each one. Stops at the "0 0" terminator, at
// end of input, or when a line cannot be parsed.
int main(void)
{
    int a,b,c;
    // scanf returns EOF (a non-zero value) at end of input, so the result is
    // compared with the expected number of conversions; a plain truthiness
    // test would loop forever when the "0 0" terminator is missing.
    while(scanf("%d%d",&n,&k)==2&&(n+k))
    {
        for(int i=0;i<=n;i++)g[i].clear();
        memset(vis,0,sizeof(vis));
        for(int i=1;i<n;i++)
        {
            // Stop on truncated input rather than index g[] with stale values.
            if(scanf("%d%d%d",&a,&b,&c)!=3)return 0;
            g[a].push_back(mk(b,c));
            g[b].push_back(mk(a,c));
        }
        f[0] = size = n; // sentinel score: start from the largest possible value
        getroot(1,rt=0);
        ans = 0;
        solve(rt);
        printf("%d\n",ans);
    }
    return 0;
}