POJ1741 Tree(树分治——点分治)题解

时间:2023-03-08 20:14:54

题意:给一棵树,问你最多能找到几个组合(u,v),使得两点距离不超过k。

思路:点分治,复杂度O(nlogn*logn)。看了半天还是有点模糊。

显然,所有满足要求的组合,连接这两个点,他们必然经过他们的最小公共子树。

参考:【poj1741】Tree 树的点分治

代码:

#include<set>
#include<map>
#include<stack>
#include<cmath>
#include<queue>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
typedef long long ll;
const int maxn = + ;
const int seed = ;
const ll MOD = 1e9 + ;
const int INF = 0x3f3f3f3f;
using namespace std;
struct Edge{
int v, w, next;
}edge[maxn << ];
int dis[maxn], sz[maxn], maxv[maxn];
//到root距离,子树大小(包括自己),最大孩子
int tot, num, ans, n, k, Max, root, head[maxn]; //root重心
bool vis[maxn];
void addEdge(int u, int v, int w){
edge[tot].v = v;
edge[tot].w = w;
edge[tot].next = head[u];
head[u] = tot++;
} //子树大小
void dfs_sz(int u, int pre){
sz[u] = ;
maxv[u] = ;
for(int i = head[u]; i != -; i = edge[i].next){
int v = edge[i].v;
if(v == pre || vis[v]) continue;
dfs_sz(v, u);
sz[u] += sz[v];
if(maxv[u] < sz[v])
maxv[u] = sz[v];
}
} //找以u为根的子树的重心
void dfs_root(int r, int u, int pre){
maxv[u] = max(maxv[u], sz[r] - sz[u]);
//sz[r]-sz[u]是u上面部分的树的尺寸,跟u的最大孩子比,找到最大孩子的最小差值节点
if(maxv[u] < Max){
Max = maxv[u];
root = u;
}
for(int i = head[u]; i != -; i = edge[i].next){
int v = edge[i].v;
if(v == pre || vis[v]) continue;
dfs_root(r, v, u);
}
} //离重心距离
void dfs_dis(int u, int d, int pre){
dis[num++] = d;
for(int i = head[u]; i != -; i = edge[i].next){
int v = edge[i].v;
if(v == pre || vis[v]) continue;
dfs_dis(v, d + edge[i].w, u);
}
} //经过u的满足条件的组合的数量
int cal(int u, int d){
int ret = ;
num = ;
dfs_dis(u, d, -);
sort(dis, dis + num);
int i = , j = num - ;
while(i < j){
while(dis[i] + dis[j] > k && i < j)
j--;
ret += j - i;
//i到i+1~j满足
i++;
}
return ret;
} void dfs(int u){
Max = n;
dfs_sz(u, -);
dfs_root(u, u, -);
ans += cal(root, );
vis[root] = true;
for(int i = head[root]; i != -; i = edge[i].next){
int v = edge[i].v;
if(!vis[v]){
ans -= cal(v, edge[i].w);
dfs(v);
}
}
} void init(){
tot = ans = ;
memset(head, -, sizeof(head));
memset(vis, false, sizeof(vis));
} int main(){
while(scanf("%d%d", &n, &k) && n + k){
init();
int u, v, w;
for(int i = ; i < n - ; i++){
scanf("%d%d%d", &u, &v, &w);
addEdge(u, v ,w);
addEdge(v, u, w);
}
dfs();
printf("%d\n", ans);
}
return ;
}