poj1741 Tree (求树上任意两点之间权值和小于k的个数)(树分治)

时间:2021-12-25 04:19:47

题意:给你n个节点的树和k,问在这个树上两点之间最近距离小于k的情况有多少种?
思路:
看了两天题解(有些还写错)和一篇关于树分治的论文分治算法在数的路径问题中的应用才知道这是一类我从来没有做过的思想,在树上利用重心分治的搞一下把O(n)的步骤优化到O(logn).
先分析:
假定选择一点1为根,那其他点到根的最短距离就有两种情况。
其一,它们在根的不同分支上,那他们的最近距离就是它们到它们的最近公共祖先的距离和。
其二,如果它们在根的同一个分支上,那么最近距离就是它们之间的距离。

我们分析到这,先求出其他点到根1点的距离用一个dfs求出,放进一个数组里depth[maxn],对它排一个序,这里面任意两点只有以上两种情况,来自1同一个分支,或是不同分支,而我们求得方案数一定是来自不同分支。直接求不好求,我们可以先求出
A:depth[i]+depth[j] <=k的方案数(不管i,j是否来自同一个分支)
B:depth[i]+depth[j]<=k的方案数(i,j来自同一个分支)
B怎么求呢?我们只要在求depth[]数组的过程在A的前提下加一个边的权值,这样就相当于控制i,j两条边的方向就是来自于这一条边的方向。
在讨论一下:对于有序depth[]数组怎么求方案数呢?类似尺取法,每次取一个区间平移。具体看代码,很好理解。
每次求完一颗子树方案数后,重心都会变换。这个写的时候需要注意。(开始用vector写邻接表没考虑到这一细节,W一上午)。
关于找重心,降低时间复杂度,在大牛的论文里证明的很详细。

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
template<int N,int M>//N点的个数,M边的个数
struct Graph
{
int top;
struct Vertex{
int head;
}V[N];
struct Edge{
int v,w,next;
}E[M];
void init(){
memset(V,-1,sizeof(V));
top = 0;
}
void add_edge(int u,int v,int w){
E[top].v = v;
E[top].next = V[u].head;
E[top].w=w;
V[u].head = top++;
}
};
const int maxn=10000+10;

Graph<maxn,maxn*2> g;
int n,k,num[maxn],mx[maxn],vis[maxn],ans;//num[i]表示以i为根的子树的大小,mx[i]表示去掉i点后剩下子树最大的大小

void init(){
g.init();
memset(vis,0,sizeof(vis));
}

void dfssize(int u,int f){//求每棵子树的大小
num[u]=1;
mx[u]=0;
for(int i=g.V[u].head;i!=-1;i=g.E[i].next){
int v=g.E[i].v;
if(vis[v]||v==f) continue;
dfssize(v,u);
num[u]+=num[v];
mx[u]=max(mx[u],num[v]);
}
}

int mi,root;//每次求子树的重心

void dfsroot(int r,int u,int f){
mx[u]=max(mx[u],num[r]-num[u]);//求子树重心里考虑另一个方向的子树大小
if(mx[u]<mi) mi=mx[u],root=u;
for(int i=g.V[u].head;i!=-1;i=g.E[i].next){
int v=g.E[i].v;
if(vis[v]||f==v) continue;
dfsroot(r,v,u);
}
}

int depth[maxn],d;

void dfs_depth(int u,int f,int dis){//求根下子孙到根的距离
depth[d++]=dis;
for(int i=g.V[u].head;i!=-1;i=g.E[i].next){
int v=g.E[i].v;
if(vis[v]||v==f) continue;
dfs_depth(v,u,dis+g.E[i].w);
}
}

int cal(int u,int dis){//计算方案数
int tem=0;
d=0;
dfs_depth(u,-1,dis);
sort(depth,depth+d);
int i=0,j=d-1;
while(i<j){
while(depth[i]+depth[j]>k&&i<j) j--;
tem+=j-i;
i++;
}
return tem;
}

void dfs_ans(int u){
mi=n;
dfssize(u,0);
dfsroot(u,u,0);
ans+=cal(root,0);
vis[root]=1;
for(int i=g.V[root].head;i!=-1;i=g.E[i].next){
int v=g.E[i].v;
if(vis[v]) continue;
ans-=cal(v,g.E[i].w);
dfs_ans(v);
}
}

int main(){
while(scanf("%d%d",&n,&k),n||k){
init();
for(int i=0;i<n-1;i++){
int a,b,c;scanf("%d%d%d",&a,&b,&c);
g.add_edge(a,b,c);
g.add_edge(b,a,c);
}
ans=0;
dfs_ans(1);
printf("%d\n",ans);
}
}