poj_1741——树的分治

时间:2022-10-04 04:24:14
看了网上各种大神的树的分治的模板,然后自己敲了一个。。。直接上代码了,晚上再写一个学习笔记, 丧心病狂的poj,上次一直跪在vector上,这次觉得不用vector写了。。。
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <vector>
#include <map>
#include <cstring>
#define ll long long
#define INF 1<<30

using namespace std;

const int N = 10000+5;
int cnt,s_max,rt,ans,n,k,num,list[N];///list表装的是整棵子树的节点
int size,ss[N];///size表示的是分出来的树的总结点数,ss表示该每颗子树的最大节点数
int son[N],d[N],head[N<<2]; ///保存节点子树的最大节点值
bool vis[N];///vis表示删除的节点

struct edge{
    int to,next;
    int w;
}e[N<<2];

void add_edge(int u,int v,int w){
    e[cnt].to = v;
    e[cnt].w = w;
    e[cnt].next = head[u];
    head[u] = cnt++;
}

///找重心
void dfs_0(int u,int fa){
    int tem = -1;
    son[u] = 1;
    list[size++] = u;
    for(int i = head[u];i != -1;i = e[i].next){
        int v = e[i].to;
        if(v == fa || vis[v]) continue;
        dfs_0(v,u);
        son[u] += son[v];
        tem = max(son[v],tem);
    }
    ss[u] = tem;
}

void getroot(int u,int fa){
    dfs_0(u,fa);
    s_max = INF;
    for(int i=0;i<size;i++){
        int tem = max(ss[list[i]],size-son[list[i]]);
        if(tem<s_max){
            rt = list[i];
            s_max = tem;
        }
    }
}

///求距离根的距离
void dfs(int u,int dis,int fa){
    d[num++] = dis;
    for(int i = head[u];i!=-1;i=e[i].next){
        int v = e[i].to;
        if(v != fa && !vis[v]){
            dfs(v,dis+e[i].w,u);
        }
    }
}

///计算节点对
int calc(int u,int dis){
    int ret = 0;
    num = 0;
    dfs(u,dis,-1);
    sort(d,d+num);
    int i = 0, j = num - 1;
    ///排序后统计点对的个数,o(n)的复杂度
    while(i < j)
    {
        while(d[i] + d[j] > k && i < j) j--;
        ret += j - i;
        i++;
    }
    return ret;
}

void solve(int u){
    size = 0;
    getroot(u,-1);
    ans += calc(rt,0);
    vis[rt] = true;
    for(int i = head[rt];i != -1;i=e[i].next){
        int v = e[i].to;
        if(!vis[v]){
            ans -= calc(v,e[i].w);
            solve(v);
        }
    }
}

void init(){
    memset(vis,false,sizeof(vis));
    memset(head,-1,sizeof(head));
    ans = 0;
    cnt = 0;
}

int main()
{
    freopen("test.in","r",stdin);
    while(scanf("%d%d",&n,&k)!=EOF && (n || k)){
        int x,y,dis;
        init();
        for(int i=0;i<n-1;i++){
            scanf("%d%d%d",&x,&y,&dis);
            add_edge(x,y,dis);
            add_edge(y,x,dis);
        }
        solve(1);
        printf("%d\n",ans);
    }
    return 0;
}