给一个树, 然后每次询问给出2个点, 让你求出到这两个点的距离相等的点的距离。
将1当成根, dfs一遍求出每个点的高度
1 如果两个点相等, 那么答案显然是n。
2. 如果两个点的距离是奇数, 那么显然无解。
3. 如果两个点高度相等, 那么我们找到他们两个的lca, 显然这两个点到他们lca这个点, 这两条链上的任意一个点都不满足要求。 其他的点全都满足要求。
4 如果高度不等, 需要找到这两个点的中点, 然后这两个点到他们中点的链上任意一个点都不满足要求。
具体实现看代码
#include <iostream> #include <vector> #include <cstdio> #include <cstring> #include <algorithm> #include <complex> #include <cmath> #include <map> #include <set> #include <string> #include <queue> #include <stack> #include <bitset> using namespace std; #define pb(x) push_back(x) #define ll long long #define mk(x, y) make_pair(x, y) #define lson l, m, rt<<1 #define mem(a) memset(a, 0, sizeof(a)) #define rson m+1, r, rt<<1|1 #define mem1(a) memset(a, -1, sizeof(a)) #define mem2(a) memset(a, 0x3f, sizeof(a)) #define rep(i, n, a) for(int i = a; i<n; i++) #define fi first #define se second typedef complex <double> cmx; typedef pair<int, int> pll; const double PI = acos(-1.0); const double eps = 1e-8; const int mod = 1e9+7; const int inf = 1061109567; const int dir[][2] = { {-1, 0}, {1, 0}, {0, -1}, {0, 1} }; const int maxn = 1e5+5; int head[maxn], num, n, f[maxn][18], son[maxn], d[maxn]; struct node { int to, nextt; }e[maxn*2]; void add(int u, int v) { e[num].to = v, e[num].nextt = head[u], head[u] = num++; } void init() { num = 0; mem1(head); mem1(f); } void dfs(int u, int fa) { son[u] = 1; for(int i = head[u]; ~i; i = e[i].nextt) { int v = e[i].to; if(v == fa) continue; f[v][0] = u; d[v] = d[u]+1; dfs(v, u); son[u] += son[v]; } } void lcaInit() { for(int j = 1; j < 18; j++) { for(int i = 1; i <= n; i++) { if(f[i][j-1] != -1) f[i][j] = f[f[i][j-1]][j-1]; } } } int lca(int u, int v) { int i, j; if(d[u] < d[v]) swap(u, v); for(i = 0; (1<<i) <= d[u]; i++); i--; for(j = i; j >= 0; j--) { if(d[u]-(1<<j) >= d[v]) u = f[u][j]; } if(u == v) return u; for(j = i; j >= 0; j--) { if(f[u][j] != -1 && f[v][j] != f[u][j]) { u = f[u][j]; v = f[v][j]; } } if(u != v) u = f[u][0]; return u; } int getNode(int u, int k) { int len = 1; while((1<<len) <= k) len++; len--; for(int i = len; i >= 0; i--) { if(k >= (1<<i)) { k -= (1<<i); u = f[u][i]; } } return u; } int main() { cin>>n; int u, v, m; init(); for(int i = 0; i < n - 1; i++) { scanf("%d%d", &u, &v); add(u, v); add(v, u); } dfs(1, 0); lcaInit(); cin>>m; while(m--) { scanf("%d%d", &u, &v); if(u == v) { printf("%d\n", n); continue; } int dis = d[u]+d[v]-2*d[lca(u, v)]; if(dis%2) { puts("0"); continue; } if(d[u] == d[v]) { u = getNode(u, dis/2-1); v = getNode(v, dis/2-1); printf("%d\n", n-son[u]-son[v]); } else { if(d[v]>d[u]) swap(u, v); v = getNode(u, dis/2-1); u = getNode(u, dis/2); printf("%d\n", son[u]-son[v]); } } return 0; }