POJ1741 Tree(点分治模板题)

时间:2022-04-29 04:24:38

POJ - 1741

参考题解:http://blog.csdn.net/yang_7_46/article/details/9966455

题意:给一棵边带权树,问两点之间的距离小于等于K的点对有多少个。

题解:

POJ1741 Tree(点分治模板题)

!对于一棵有根数,符合条件的路径分为两类
!!经过root
!!未经过root(存在于子树)
!因此可以用分治的思想,先处理完经过根节点的,再处理子树
!对于一棵树,如何快速查找经过root的符合条件的路径数?
!!查找重心(O(n))
!!处理出所有节点到重心的距离(O(n))
!!sort所有距离之后O(n)求出满足条件的节点对(O(nlogn))
!!扣除存在于子树的情况

#include<cstdio>
#include<vector>
#include<algorithm>
#include<cstring>
using namespace std;
#define mp make_pair

const int N=10005;
vector<pair<int,int> > G[N];
vector<int> dep;
int n,k,root,size,ans;
int f[N],s[N],d[N];
bool done[N];

void find(int u,int fa) {
    f[u]=0;s[u]=1;
    int sz=G[u].size();
    for(int i=0;i<sz;++i) {
        int v=G[u][i].first;
        int w=G[u][i].second;
        if(done[v]||v==fa) continue;
        find(v,u);
        f[u]=max(f[u],s[v]);
        s[u]+=s[v];
    }
    f[u]=max(f[u],size-s[u]);
    if(f[u]<f[root]) root=u;
}

void getDep(int u,int fa) {
    s[u]=1;dep.push_back(d[u]);
    int sz=G[u].size();
    for(int i=0;i<sz;++i) {
        int v=G[u][i].first;
        int w=G[u][i].second;
        if(done[v]||v==fa) continue;
        d[v]=d[u]+w;
        getDep(v,u);
        s[u]+=s[v];
    }
}

int calc(int u,int fa,int init) {
    d[u]=init;
    dep.clear();
    getDep(u,fa);
    sort(dep.begin(),dep.end());
    int ret=0;
    for(int l=0,r=dep.size()-1;l<r;) {
        if(dep[l]+dep[r]<=k) {
            ret+=r-l++;
        } else {
            r--;
        }
    }
    return ret;
}

void solve(int u,int fa) {
    done[u]=1;
    ans+=calc(u,fa,0);
    int sz=G[u].size();
    for(int i=0;i<sz;++i) {
        int v=G[u][i].first;
        int w=G[u][i].second;
        if(done[v]||v==fa) continue;
        ans-=calc(v,u,w);
        f[0]=size=s[v];
        find(v,root=0);
        solve(root,0);
    }
}

int main() {
    while(~scanf("%d%d",&n,&k)) {
        if(n==0&&k==0) break;
        for(int i=0;i<=n;++i) G[i].clear();
        int u,v,w;
        for(int i=1;i<n;++i) {
            scanf("%d%d%d",&u,&v,&w);
            G[u].push_back(mp(v,w));
            G[v].push_back(mp(u,w));
        }
        memset(done,0,sizeof(done));
        f[0]=size=n;
        find(1,root=0);
        ans=0;
        solve(root,0);
        printf("%d\n",ans);
    }
    return 0;
}