题目链接:
题目大意:
给出一棵树和m次询问,每次询问给出两个点,求出到这两个点距离相等的点的个数。
题目分析:
- 设dp[u]记录每个点的深度,fa[u][i]代表比u深度小
2i 的父亲节点。siz[u]记录u为根的子树的大小。 - 我们dfs一遍处理出那些数组,然后在线求取lca,查询分为两种情况
- 如果其中一个点是这次查询两个点的公共祖先,那么如果他们之间的距离如果是奇数,那么无解,如果是偶数,那么到两个点距离相等的点的儿子中除了查询点的路径上的儿子的子树上所有的点都是符合条件的。
- 否则,看到两个查询点距离相等的点x是不是他们的公共祖先,如果是的话,那么就是除了查询点之间路径上的点所在子树的其他所有点,如果不是他们的公共祖先的话,那么就是路径相等的那个点子树中所有的点减去深度较深的查询点的子树中的点剩余点的数目。
AC代码:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>
#define MAX 100007
using namespace std;
int fa[MAX][25];
int dp[MAX];
int siz[MAX];
int n,m;
vector<int> e[MAX];
void add ( int u , int v )
{
e[u].push_back ( v );
e[v].push_back ( u );
}
void dfs ( int u , int p )
{
if ( p == -1 ) dp[u] = 1;
siz[u] = 1;
for ( int i = 1 ; i < 21 ; i++ )
fa[u][i] = fa[fa[u][i-1]][i-1];
for ( int i = 0 ; i < e[u].size() ; i++ )
{
int v = e[u][i];
if ( v == p ) continue;
dp[v] = dp[u]+1;
fa[v][0] = u;
dfs ( v , u );
siz[u] += siz[v];
}
}
int lca ( int u , int v )
{
if ( dp[u] < dp[v] )
swap ( u , v );
for ( int i = 20 ; i >= 0 ; i-- )
{
if ( dp[fa[u][i]] >= dp[v] )
u = fa[u][i];
if ( u == v )
return u;
}
for ( int i = 20 ; i >= 0 ; i-- )
if ( fa[u][i] != fa[v][i] )
{
v = fa[v][i];
u = fa[u][i];
}
return fa[u][0];
}
void init ( )
{
for ( int i = 0 ; i < MAX ; i++ )
e[i].clear();
}
int main ( )
{
while ( ~scanf ("%d" , &n ))
{
init ();
for ( int i = 1 ; i < n ; i++ )
{
int x,y;
scanf ( "%d%d" , &x , &y );
add ( x , y );
}
dfs ( 1 , -1 );
scanf ("%d" , &m );
while ( m-- )
{
int x,y,ans;
scanf ( "%d%d" , &x , &y );
if ( x == y )
{
printf ("%d\n" , n );
continue;
}
int f = lca ( x , y );
if ( dp[x] > dp[y] ) swap ( x , y );
if ( f == x )
{
int l = dp[y] - dp[x];
if ( l&1 ) ans = 0;
else
{
l /= 2;
l--;
int fy = y,fx;
int num = 0;
while ( l )
{
if ( l&1 ) fy = fa[fy][num];
num++;
l >>= 1;
}
fx = fa[fy][0];
ans = siz[fx] - siz[fy];
}
}
else
{
int l = dp[x]+dp[y]-2*dp[f];
if ( l&1 ) ans = 0;
else
{
l /= 2;
l--;
int ll = l;
int fy = y ,fx;
int num = 0;
while ( l )
{
if ( l&1 ) fy = fa[fy][num];
num++;
l >>= 1;
}
if ( dp[x] == dp[y] )
{
fx = x;
int num = 0;
while ( ll )
{
if ( ll&1 ) fx = fa[fx][num];
num++;
ll >>= 1;
}
ans = n - siz[fx] - siz[fy];
}
else
{
fx = fa[fy][0];
ans = siz[fx] - siz[fy];
}
}
}
printf ( "%d\n" , ans );
}
}
}