题意:给你n个点,n-1条边,然后有m次查询,问u,v到两点距离相等的点有多少个。
分析:因为只有n-1条边,所以任意两点的路径只有唯一的一条,那么到两点距离相等的点,只有与这个路径的中点相连的那些点,如何求得中点呢?直接求的话肯定超时,我们可以把这个图转化成一个树,用一次dfs遍历建树,然后求得lca即可知道这条路径长度s,u,v两点中层数较深的那点往上走s/2步即为中点。
找到可中点如何求的结果呢?设u为层数较深的那个点,up(x,t)为点x向上走t步。首先路径长度为奇数的肯定无解,因为没有中点存在;u,v在同一层结果是n-son[up(u,s/2-1)] -son[up(v,s/2-1)];否则结果就是son[up(u,s/2)] - son[up(u,s/2-1)];最后有一个坑点那就是u == v,直接输出n。
开始不会写这题,后来参考了别人的代码,学会了用倍增法求lca。
看了这个帖才懂的倍增法:http://www.tuicool.com/articles/N7jQV32
AC代码:
#include <algorithm> #include <iostream> #include <cstdio> #include <vector> #include <stack> #include <queue> using namespace std; const int maxn = 100007; const int N = 17; struct tree{ int p[maxn][N]; int d[maxn]; int son[maxn]; void dfs(int u); void init(); int up(int x, int t); int lca(int u, int v); int result(int u, int v);//结果 }; vector<int> list[maxn]; tree tr; int n,m; void tree::dfs(int u){//dfs建树 for(int i = 0; i < list[u].size(); i++){ int v = list[u][i]; if(v != p[u][0]){ d[v] = d[u] + 1; p[v][0] = u; for(int j = 1; j < N; j++){ p[v][j] = p[p[v][j-1]][j-1]; } dfs(v); son[u] += son[v]; } } } void tree::init(){ for(int i = 1; i <= n; i++){ son[i] = 1; } for(int j = 0; j < N; j++){ p[1][j] = 1; } d[1] = 1; dfs(1); } int tree::up(int x, int t){//点x往上走t步 for(int i = 0; i < N; i++) if(t&(1<<i)) x = p[x][i]; return x; } int tree::lca(int u, int v){// if(d[u] > d[v]) u = up(u,d[u]-d[v]); if(u == v) return u; for(int i = N-1; i >= 0; i--) if(p[u][i] != p[v][i]) u = p[u][i], v = p[v][i]; return p[u][0]; } int tree::result(int u, int v){ if(u == v) return n; else{ if(d[u] < d[v]) swap(u,v); int x = lca(u,v); int s = d[u]-d[x] + d[v]-d[x];//距离 if(s&1) return 0;//无解 if(d[u] == d[v]){ u = up(u,s/2-1); v = up(v,s/2-1); return n - son[u] - son[v]; } else{ u = up(u,s/2-1); return son[p[u][0]] - son[u]; } } } void input(){ scanf("%d",&n); int u,v; for(int i = 0; i < n-1; i++){ scanf("%d%d",&u,&v); list[u].push_back(v); list[v].push_back(u); } } void solve(){ tr.init(); scanf("%d",&m); int u,v; for(int i = 0; i < m; i++){ scanf("%d%d",&u,&v); printf("%d\n",tr.result(u,v)); } } int main(){ input(); solve(); return 0; }