CodeForces 519E A and B and Lecture Rooms

时间:2022-03-20 20:52:17

题意:给你n个点,n-1条边,然后有m次查询,问u,v到两点距离相等的点有多少个。


分析:因为只有n-1条边,所以任意两点的路径只有唯一的一条,那么到两点距离相等的点,只有与这个路径的中点相连的那些点,如何求得中点呢?直接求的话肯定超时,我们可以把这个图转化成一个树,用一次dfs遍历建树,然后求得lca即可知道这条路径长度s,u,v两点中层数较深的那点往上走s/2步即为中点。

找到可中点如何求的结果呢?设u为层数较深的那个点,up(x,t)为点x向上走t步。首先路径长度为奇数的肯定无解,因为没有中点存在;u,v在同一层结果是n-son[up(u,s/2-1)] -son[up(v,s/2-1)];否则结果就是son[up(u,s/2)] - son[up(u,s/2-1)];最后有一个坑点那就是u == v,直接输出n。

开始不会写这题,后来参考了别人的代码,学会了用倍增法求lca。

看了这个帖才懂的倍增法:http://www.tuicool.com/articles/N7jQV32


AC代码:

#include <algorithm>
#include <iostream>
#include <cstdio>
#include <vector>
#include <stack>
#include <queue>

using namespace std;

const int maxn = 100007;
const int N = 17;

struct tree{
	int p[maxn][N];
	int d[maxn];
	int son[maxn];
	
	void dfs(int u);
	void init();
	int up(int x, int t);
	int lca(int u, int v);
	int result(int u, int v);//结果 
};

vector<int> list[maxn];
tree tr;
int n,m;


void tree::dfs(int u){//dfs建树 
	for(int i = 0; i < list[u].size(); i++){
		int v = list[u][i];
		if(v != p[u][0]){ 
			d[v] = d[u] + 1;
			p[v][0] = u;
			for(int j = 1; j < N; j++){
				p[v][j] = p[p[v][j-1]][j-1];
			}
			dfs(v);
			son[u] += son[v];
		}
	}
}

void tree::init(){
	for(int i = 1; i <= n; i++){
		son[i] = 1;		
	}
	for(int j = 0; j < N; j++){
			p[1][j] = 1;
	}
	d[1] = 1;
	dfs(1); 
}

int tree::up(int x, int t){//点x往上走t步 
	for(int i = 0; i < N; i++)
		if(t&(1<<i)) x = p[x][i];
	return x;
}

int tree::lca(int u, int v){//
	if(d[u] > d[v]) u = up(u,d[u]-d[v]);
	if(u == v) return u;
	for(int i = N-1; i >= 0; i--)
		if(p[u][i] != p[v][i]) u = p[u][i], v = p[v][i];
	
	return p[u][0];
}

int tree::result(int u, int v){
	if(u == v) return n;
	else{
		if(d[u] < d[v]) swap(u,v);
		int x = lca(u,v);
		int s = d[u]-d[x] + d[v]-d[x];//距离 
		
		if(s&1) return 0;//无解 
		
		if(d[u] == d[v]){
			u = up(u,s/2-1);
			v = up(v,s/2-1);
			return n - son[u] - son[v];
		}
		else{
			u = up(u,s/2-1);
			return son[p[u][0]] - son[u];
		}
	}
}


void input(){
	scanf("%d",&n);
	int u,v;
	for(int i = 0; i < n-1; i++){
		scanf("%d%d",&u,&v);
		list[u].push_back(v);
		list[v].push_back(u);
	}
}

void solve(){
	tr.init();
	scanf("%d",&m);
	int u,v;
	for(int i = 0; i < m; i++){
		scanf("%d%d",&u,&v);
		printf("%d\n",tr.result(u,v));
	}
}

int main(){
	input();
	solve();
	return 0;
}