POJ 1741 树的点分治

时间:2022-10-15 04:23:13

题目链接

题意:

求一颗树上距离<=k的点对数。

思路:

由一个根向下求得距离后,将距离经过排序可以O(n)时间算的以该点为根的子树下距离<=k的点对数。

点对存在两种情况:

1.两点来自不同子树

2.两点来自相同子树

不难发现如果我们递归向下求解子树中的点对数时,对于两点来自相同子树的点对我们会重复计算,所以我们进行递归求解子树时需要先将重复计算减去。

如果按以上操作进行并提交,你将会得到一个TLE。为什么呢?

求解距离时我们会遍历整颗子树的所有点即复杂度为O(sizes[v])

当树是一条链状时,那么时间复杂度将会高达O(n^2)

既然时间复杂度是O(sizes[v])那么我们将O(sizes[v])变得尽量小不就可以优化了嘛

分解子树的点数尽量少-->树的重心||链状树重心分割后子树点数最大为sizes[u]/2->时间复杂度降到nlog2(n)

C++代码:

#include<map>
#include<set>
#include<stack>
#include<cmath>
#include<queue>
#include<vector>
#include<string>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn = 10010;
const int maxm = 20010;

int n,k,tol,head[maxn];
struct edge
{
    int to,next,cost;
}es[maxm];

void addedge( int u , int v , int w )
{
    es[tol].to = v;
    es[tol].cost = w;
    es[tol].next = head[u];
    head[u] = tol++;
}

int Size,Root,vis[maxn],sizes[maxn],son[maxn];

void GetRoot( int u , int f )
{
    sizes[u] = 1; son[u] = 0;
    for ( int i=head[u] ; i!=-1 ; i=es[i].next )
    {
        int v = es[i].to,w = es[i].cost;
        if ( v!=f&&!vis[v] )
        {
            GetRoot( v , u );
            sizes[u] += sizes[v];
            son[u] = max ( son[u] , sizes[v] );
        }
    }
    son[u] = max( son[u] , Size-sizes[u] );
    if ( Root==-1||son[u]<son[Root] ) Root = u;
}

int ans,que[maxn],len,dis[maxn];

void GetDep( int u , int f )
{
    que[len++] = dis[u];
    for ( int i=head[u] ; i!=-1 ; i=es[i].next )
    {
        int v = es[i].to,w = es[i].cost;
        if ( v!=f&&!vis[v] )
        {
            dis[v] = dis[u]+w;
            GetDep( v , u );
        }
    }
}

int cal( int u , int c )
{
    len = 0; dis[u] = c;
    GetDep( u , 0 );
    sort( que , que+len );
    int res =0;
    for ( int i=0,j=len-1 ; i<j ; )
        if ( que[i]+que[j]<=k ) res += j-i,i++;
        else j--;
    return res;
}

void slove( int u )
{
    ans += cal( u , 0 );
    vis[u] = 1;
    for ( int i=head[u] ; i!=-1 ; i=es[i].next )
    {
        int v = es[i].to,w = es[i].cost;
        if ( !vis[v] )
        {
            ans -= cal( v , w );
            Size = sizes[v]; Root = -1;
            GetRoot( v , 0 );
            slove( Root );
        }
    }
}

int main()
{
    while( scanf ( "%d%d" , &n , &k )==2 )
    {
        if ( n==0&&k==0 ) break;
        tol = 0;
        memset( head , -1 , sizeof(vis) );
        for ( int i=1 ; i<n ; i++ )
        {
            int u,v,w;
            scanf ( "%d%d%d" , &u , &v , &w );
            addedge ( u , v , w );
            addedge ( v , u , w );
        }
        memset( vis , 0 , sizeof(vis) );
        Size = n; Root = -1;
        GetRoot( 1 , 0 );
        ans = 0;
        slove( Root );
        printf ( "%d\n" , ans );
    }
    return 0;
}