普通dfs访问每个点对的复杂度是O(n^2)的,显然会超时。
考虑访问到当前子树的根节点时,统计所有经过根的点(u, v)满足:
dist(u) + dist(v) <= maxd,并且
belong(u)≠belong(v)(即u,v不在同一子树)。
这里说的距离指的是节点到跟的距离。
可以用作差法,即用所有满足条件的点对数减去那些在根节点为当前子树根节点的儿子节点的点对数。
上面一步可以用O(nlogn)的复杂度解决,即先排序再比较。
根节点子树可以递归解决,用树的点分治。
总复杂度上界是O(nlognlogn)。
http://poj.org/problem?id=1987
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + ;
const int inf = 0x3f3f3f3f;
struct Edge{
int to, next, w;
}edge[maxn << ];
int head[maxn], N;
int n, m, maxd;
bool vis[maxn];
int cnt[maxn];
int maxi[maxn];
int ans;
int mini, root, sum;
int buf[maxn], k; void addEdge(int u, int v, int w){
edge[N].next = head[u];
edge[N].to = v;
edge[N].w = w;
head[u] = N++;
} void init(){
N = ;
memset(head, -, sizeof head);
} void get_cnt(int u){
cnt[u] = ;
vis[u] = ;
maxi[u] = -;
buf[k++] = u;
for(int i = head[u]; i + ; i = edge[i].next){
int v = edge[i].to;
if(vis[v]) continue;
get_cnt(v);
cnt[u] += cnt[v];
maxi[u] = max(maxi[u], cnt[v]);
}
vis[u] = ;
} void get_dist(int u, int d){
vis[u] = ;
buf[k++] = d;
for(int i = head[u]; i + ; i = edge[i].next){
int v = edge[i].to;
if(vis[v]) continue;
get_dist(v, d + edge[i].w);
}
vis[u] = ;
} int get_buf_sum(int left, int right){
sort(buf + left, buf + right);
int tem = ;
for(int i = right - , j = left; i >= left + ; i--){
while(j < i && buf[i] + buf[j] <= maxd) ++j;
tem += min(i, j) - left;
}
return tem;
} void cal(int u){
k = ;
get_cnt(u);
mini = inf;
for(int i = ; i < k; i++){
int tem = max(cnt[u] - cnt[buf[i]], maxi[buf[i]]);
if(tem < mini) mini = tem, root = buf[i];
}
k = ;
vis[root] = ;
int tem = ;
for(int i = head[root]; i + ; i = edge[i].next){
int v = edge[i].to;
if(vis[v]) continue;
int pre = k;
get_dist(v, edge[i].w);
tem -= get_buf_sum(pre, k);
}
buf[k++] = ;
tem += get_buf_sum(, k);
ans += tem;
for(int i = head[root]; i + ; i = edge[i].next){
int v = edge[i].to;
if(vis[v]) continue;
cal(v);
}
} void solve(){
scanf("%d", &maxd);
memset(vis, , sizeof vis);
ans = ;
for(int i = ; i <= n; i++){
if(!vis[i]) cal(i);
}
printf("%d\n", ans);
} int main(){
//freopen("in.txt", "r", stdin);
while(~scanf("%d", &n)){
scanf("%d", &m);
init();
for(int i = , x, y, z; i < m; i++){
scanf("%d%d%d", &x, &y, &z);
addEdge(x, y, z);
addEdge(y, x, z);
getchar(), getchar();
}
solve();
}
return ;
}