【POJ1741】树中点对统计 点分治

时间:2021-11-22 15:48:30

题目描述

  给定一棵N(1<=N<=100000)个结点的带权树,每条边都有一个权值(为正整数,小于等于1001)。定义dis(u,v)为u,v两点间的最短路径长度,路径的长度定义为路径上所有边的权和。再给定一个K(1<=K<=10^9),如果对于不同的两个结点u,v,如果满足dist(u,v)<=K,则称(u,v)为合法点对。求合法点对个数。

题目大意

  求树中距离小于k的点对个数

数据范围

对于50%的数据,n<=1000,k<=1000;
对于100%的数据,n<=100000,k<=10^9;

样例输入

5 4
1 2 3
1 3 1
1 4 2
3 5 1

样例输出

8

解题思路


不写了233

代码

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <ctime>
#define Maxn 100005
using namespace std;
inline int Getint(){int x=0,f=1;char ch=getchar();while('0'>ch||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while('0'<=ch&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
int fa[Maxn],h[Maxn],size[Maxn],dep[Maxn];
bool vis[Maxn];
int n,k,cnt=0,L,r,Min;
struct node{int to,next,v;}e[Maxn*2];
void AddEdge(int x,int y,int v){e[++cnt]=(node){y,h[x],v};h[x]=cnt;}
void Init(){
    n=Getint(),k=Getint();
    for(int i=1;i<n;i++){
        int x,y,v;
        x=Getint(),y=Getint(),v=Getint();
        fa[y]=x;
        AddEdge(x,y,v);
        AddEdge(y,x,v);
    }
    memset(vis,0,sizeof(vis));
}
int dfssize(int u,int pre){
    size[u]=1;
    for(int p=h[u];p;p=e[p].next){
        int y=e[p].to;
        if(vis[y]||y==pre)continue;
        size[u]+=dfssize(y,u);
    }
    return size[u];
}
void Getroot(int u,int pre,int tot,int &root){
    int Max=tot-size[u];
    for(int p=h[u];p;p=e[p].next){
        int y=e[p].to;
        if(vis[y]||y==pre)continue;
        Getroot(y,u,tot,root);
        Max=max(Max,size[y]);
    }
    if(Max<Min){
        Min=Max;
        root=u;
    }
}
void Getlen(int u,int pre,int d){
    dep[r++]=d;
    for(int p=h[u];p;p=e[p].next){
        int y=e[p].to;
        if(vis[y]||y==pre)continue;
        Getlen(y,u,d+e[p].v);
    }
}
int Calc(int L,int r){
    sort(dep+L,dep+r);
    int ret=0,Pos=r-1;
    for(int i=L;i<r;i++){
        if(dep[i]>k)break;
        while(Pos>=L&&dep[i]+dep[Pos]>k)Pos--;
        ret+=Pos-L+1;
        if(Pos>i)ret--;
    }
    return ret/2;
}
int Solve(int u){
    int tot=dfssize(u,0),ret=0,root;
    Min=0x7fffffff;
    Getroot(u,0,tot,root);
    vis[root]=true;
    for(int p=h[root];p;p=e[p].next){
        int y=e[p].to;
        if(vis[y])continue;
        ret+=Solve(y);
    }
    L=r=0;
    for(int p=h[root];p;p=e[p].next){
        int y=e[p].to;
        if(vis[y])continue;
        Getlen(y,root,e[p].v);
        ret-=Calc(L,r);
        L=r;
    }
    ret+=Calc(0,r);
    for(int i=0;i<r;i++)
        if(dep[i]<=k)ret++;
        else break;
    vis[root]=false;
    return ret;
}
int main(){
    Init();
    cout<<Solve(1)<<"\n";
}