测试地址:Tree
题目大意:给定一棵有N(N≤10000)个节点的带边权的树,我们称一个点对是合法的当且仅当两个点不相同且它们之间的距离≤K,求合法点对的数目。
做法:既是树分治的论文题,又是男人八题的其中一题,妙啊......
这道题需要用到树的点分治。
因为,我们很容易想到O(N^2)的暴力,然而对于N=10000的数量级根本束手无策。
那么要怎么办呢?
我们将这棵树按根节点分治,那么两点间的路径就可以分成两种情况:过根节点的和不过根节点的。不过根节点的情况可以递归进子树求出,我们就只用考虑过根节点的情况了。设点i与根节点的距离为dis[i],那么我们要找的就是:满足dis[i]+dis[j]≤K且i,j不属于同一棵子树的点对(i,j)的数目。但是这样太难算了,我们可以把问题分解成两个问题,令X=满足dis[i]+dis[j]≤K的点对(i,j)数目,Y=满足dis[i]+dis[j]≤K且i,j属于同一棵子树的点对(i,j)数目,那么原问题答案就是X-Y。我们发现X和Y都可以转化成“给定A,求满足A[i]+A[j]≤K的点对(i,j)数目”这个问题,这个问题的解决就只需要将A排序,设B[i]为使得A[i]+A[x]≤K的最大的x,根据单调性,B[i]是单调不增的,因此解决这个问题的复杂度为O(NlogN),可以证明没有更好的办法了。
但是,如果遇到极端情况,分治的深度可能会达到N,那么就比O(N^2)的暴力还差了,所以我们可以每次分治时,找到树的重心(使得以它为树的根节点时,节点数最多的子树节点数最小的点)作为根节点,可以证明这样分治的深度不大于logN,那么问题的均摊复杂度就优化到了O(Nlog^2 N),可以通过此题。
犯二的地方:各种细节写错导致重心没找对......TLE到爆炸......
以下是本人代码:
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define inf 1000000000
using namespace std;
int n,k,tot,first[10010],maxp,ans;
int a[10010],p[10010],fa[10010],siz[10010],dis[10010];
struct edge {int v,d,next;} e[20010];
bool vis[10010];
bool cmp(int a,int b)
{
return a<b;
}
void insert(int a,int b,int d)
{
e[++tot].v=b;
e[tot].d=d;
e[tot].next=first[a];
first[a]=tot;
}
void dfs(int v)
{
siz[v]=1;
a[++a[0]]=dis[v];
p[a[0]]=v;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa[v]&&!vis[e[i].v])
{
fa[e[i].v]=v;
dis[e[i].v]=dis[v]+e[i].d;
dfs(e[i].v);
siz[v]+=siz[e[i].v];
}
}
int find(int v)
{
int s,mx=inf;
a[0]=0;fa[v]=0;dis[v]=0;dfs(v);
for(int i=1;i<=siz[v];i++)
{
int x=p[i],maxsiz=0,sumsiz=0;
for(int j=first[x];j;j=e[j].next)
if (e[j].v!=fa[x])
{
maxsiz=max(maxsiz,siz[e[j].v]);
sumsiz+=siz[e[j].v];
}
maxsiz=max(maxsiz,siz[v]-sumsiz);
if (maxsiz<mx) mx=maxsiz,s=x;
}
return s;
}
int work(int v,int start)
{
a[0]=0;dis[v]=start;dfs(v);
sort(a+1,a+siz[v]+1,cmp);
int r=1,sum=0;
while(a[1]+a[r]<=k&&r<=siz[v]) r++;
r--;
for(int i=1;i<=siz[v];i++)
{
while(a[i]+a[r]>k) r--;
if (i>=r) break;
sum+=r-i;
}
return sum;
}
void solve(int v)
{
v=find(v);
fa[v]=0;
ans+=work(v,0);
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v]) ans-=work(e[i].v,dis[e[i].v]);
vis[v]=1;
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v]) solve(e[i].v);
}
int main()
{
while(scanf("%d%d",&n,&k)&&n)
{
tot=0;
memset(first,0,sizeof(first));
memset(vis,0,sizeof(vis));
for(int i=1,x,y,d;i<n;i++)
{
scanf("%d%d%d",&x,&y,&d);
insert(x,y,d),insert(y,x,d);
}
ans=0;
solve(1);
printf("%d\n",ans);
}
return 0;
}