最近在学习树的分治,算是比较难,而且代码量比较大的一块。随便拿一道题来就有上百行,故写一篇文章来总结一下这方面的框架。
POJ这一题应该算是树分治的入门题,顺便用这一题来详细说明树分治的一些具体内容。
http://poj.org/problem?id=1741
Description
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Input
The last test case is followed by two zeros.
Output
Sample Input
5 4 1 2 3 1 3 1 1 4 2 3 5 1 0 0
Sample Output
8
题目大意:有一个有n个节点、带边权的树(n<=10000);只有一次询问,求树上路径长度<=k的所有路径条数。每个测试点有多组数据。
首先此题有显然的N^2做法:用rmq或者lca预处理,然后枚举所有点对,但这样显然超时。
因此,在这里我们考虑树上分治的做法:
1.转化问题:我们将这棵树转化为一棵有根树,这样就可以统计 经过一个根节点 且只经过其子树中的节点的 合法路径 的条数,将经过各个点的这样的路径数加起来既是答案;如图,绿色的路径就是一条我们要求的路径;
2.如何得出经过每一个根结点的合法路径数呢?对于每一棵子树(如图中方框中 以红色结点为根的子树),我们dfs求出这棵子树中所有节点到红色根节点的距离,然后将这些距离放在dep数组中sort一下,这样就可以显然地求出这棵子树中所有合法路径条数;
3.但这样求出的不是经过一个节点的合法路径数,而是整个子树中的合法路径数;这些路径中,有些路径并不经过根节点;显然,如果这样计算求和,会有很多路径被重复计算。为了去除重复,我们对红色根节点的所有儿子节点做和第2步相同的操作,并且将每个儿子的结果都剪掉;这样就剪掉了所有只经过子树中的结点、而不经过根节点的路径数。减去之后,剩下的自然就是我们要求的经过根节点的路径数。
以上就是我们计算路径条数时的主要思想。为了降低复杂度,就要降低每次dfs时子树中结点的个数。这时我们就用到分治的思想:递归处理,每次找到重心,进行2、3步的操作;再将这个点挖掉,对剩下的子树再找重心,进行2、3步的操作,以此递归。利用重心的性质,每个子树都至少减小到上一级子树的一半,于是复杂度就降到了log级别。
值得一提的是,第3步的去重思想,在树上分治的题中有广泛且灵活的应用;本题较为基础,初学者应该对本题有透彻的把握。
代码:1236K,235MS
#include<cstdio> #include<iostream> #include<cstring> #include<algorithm> #include<vector> using namespace std; const int inf=1e5+10; int head[inf],next[inf<<1],to[inf<<1],len[inf<<1],cnt; int maxn[inf],siz[inf],G,subsiz; bool vis[inf]; int dp[inf<<1],dep[inf<<1];//dp[]存储到根节点的距离;dep[]是用来sort的,dep[0]表示dep数组中元素的个数 int n,k,ans=0; void init(void){ memset(vis,false,sizeof vis); memset(head,0,sizeof head); cnt=0;ans=0; } void addedge(int u,int v,int w){ to[++cnt]=v;len[cnt]=w; next[cnt]=head[u];head[u]=cnt; } void getG(int u,int f){//找重心 siz[u]=1;maxn[u]=0; for (int i=head[u];i;i=next[i]){ int v=to[i];if (v!=f && !vis[v]){ getG(v,u); siz[u]+=siz[v]; maxn[u]=max(maxn[u],siz[u]); } }maxn[u]=max(maxn[u],subsiz-siz[u]); G=(maxn[u]<maxn[G])?u:G; } void dfs(int u,int f){//dfs确定每个点到根节点的距离 dep[++dep[0]]=dp[u]; for (int i=head[u];i;i=next[i]){ int v=to[i];if (v!=f && !vis[v]){ dp[v]=dp[u]+len[i]; dfs(v,u); } } } int calc(int u,int inidep){//inidep是这一点相对于根节点的初始距离 dep[0]=0; dp[u]=inidep; dfs(u,0); sort(dep+1,dep+1+dep[0]); int sum=0; for (int l=1,r=dep[0];l<r;){//计算合法点对数目 if (dep[l]+dep[r]<=k) {sum+=r-l;l++;} else r--; } return sum; } void divide(int g){ //递归,找到重心并以重心为根节点进行计算,再对子树递归处理 ans+=calc(g,0); vis[g]=true; for (int i=head[g];i;i=next[i]){ int v=to[i]; if (!vis[v]){ ans-=calc(v,len[i]); maxn[0]=subsiz=siz[v];G=0;getG(v,0); divide(G); } } } int main(){ while(scanf("%d%d",&n,&k)==2){ if (!n && !k) break; init(); for (int i=1,u,v,w;i<n;i++){ scanf("%d%d%d",&u,&v,&w); addedge(u,v,w);addedge(v,u,w); } subsiz=maxn[0]=n;G=0;getG(1,0); divide(G); printf("%d\n",ans); } return 0; }