Codeforces 418d Big Problems for Organizers [树形dp][倍增lca]

时间:2023-03-09 03:44:56
Codeforces 418d Big Problems for Organizers [树形dp][倍增lca]

题意:

给你一棵有n个节点的树,树的边权都是1.

有m次询问,每次询问输出树上所有节点离其较近结点距离的最大值。

思路:

1.首先是按照常规树形dp的思路维护一个子树节点中距离该点的最大值son_dis[i],维护非子树节点中距离该点的最大值fa_dis[i];

2.对于每个节点维护它最大的三个儿子节点的son_dis;

3.维护up[i][j]和down[i][j]数组,这个类似倍增lca里边的fa[i][j],up[i][j]代表的含义是从第j个点向上到它的第2^i个父节点这条链上的点除了该节点所在子树外的距离的最大值。down[i][j]同理,但是维护的是从第2^i父节点到该点的链上除了该节点所在子树外的距离的最大值。在这里尤其注意的是,采取了类似差分的思想。看巨巨代码的时候我想了好一会。到这里预处理完毕。

4.对于给定的两个节点。假设a为深度较深的,b为深度浅的。

对于节点a,a到a的子树中所有的点肯定较近,所以son_dis[a]有可能是答案。a到a和b的中点的那条链上距离的最大值也有可能是答案。

对于b

假设b不是公共祖先,那么son_dis[b]有可能是答案。b到中点的链上的距离的最大值也有可能是答案。

若b是公共祖先,那么只有b到中点的链上的距离的最大值也有可能是答案。

对于最近公共祖先r

r的不包含a和b的子树的dis_son有可能是答案,r的fa_dis[r]有可能是答案。

最终结果是在有可能的答案中找最大值。

代码越改越挫。

#include<bits/stdc++.h>
#define MAXN 100050
#define MAXM 200050
using namespace std;
const int inf=0x3f3f3f3f;
struct st{
int num,id;
};
bool operator < (const st &a,const st &b){
return a.num>b.num;
}
multiset<st>my_set[MAXN];
struct edge{
int id;
edge *next;
};
int ednum;
edge edges[MAXM];
edge *adj[MAXN];
int dep[MAXN],son_dis[MAXN],fa_dis[MAXN],max_num[MAXN],father[MAXN],max_x[MAXN],rt[][MAXN],siz[MAXN],up[][MAXN],down[][MAXN];
bool vis[MAXN];
inline void addedge(int a,int b){
edge *tmp=&edges[ednum++];
tmp->id=b;
tmp->next=adj[a];
adj[a]=tmp;
}
void dfs(int pos,int deep){
dep[pos]=deep;
siz[pos]=;
int mmax=-;
for(edge *it=adj[pos];it;it=it->next){
if(dep[it->id]==){
father[it->id]=pos;
rt[][it->id]=pos;
dfs(it->id,deep+);
st tmp;
tmp.id=it->id;
tmp.num=son_dis[it->id];
my_set[pos].insert(tmp);
mmax=max(mmax,son_dis[it->id]);
son_dis[pos]=max(son_dis[pos],son_dis[it->id]+);
siz[pos]+=siz[it->id];
}
}
int num=;
for(edge *it=adj[pos];it;it=it->next){
if(father[it->id]==pos&&son_dis[it->id]==mmax)num++;
}
max_num[pos]=num;
max_x[pos]=mmax+;
}
void dfs2(int pos){
fa_dis[pos]=fa_dis[father[pos]]+;
if(max_num[father[pos]]>||son_dis[pos]+!=max_x[father[pos]]){
fa_dis[pos]=max(fa_dis[pos],max_x[father[pos]]+);
up[][pos]=max_x[father[pos]]-dep[father[pos]];
down[][pos]=max_x[father[pos]]+dep[father[pos]];
}
else{
int maxx=-;
for(edge *it=adj[father[pos]];it;it=it->next){
if(father[it->id]==father[pos]&&(it->id!=pos)){
maxx=max(maxx,son_dis[it->id]);
}
}
fa_dis[pos]=max(fa_dis[pos],maxx+);
if(maxx==-)maxx=-;
maxx++;
up[][pos]=maxx-dep[father[pos]];
down[][pos]=maxx+dep[father[pos]];
}
for(edge *it=adj[pos];it;it=it->next){
if(father[it->id]==pos)dfs2(it->id);
}
}
void prelca(int n){
up[][]=down[][]=-inf;
for(int i=;i<=;i++){
for(int j=;j<=n;j++){
rt[i][j]=rt[i-][j]==-?-:rt[i-][rt[i-][j]];
up[i][j]=max(up[i-][j],up[i-][rt[i-][j]]);
down[i][j]=max(down[i-][j],down[i-][rt[i-][j]]);
}
}
}
int LCA(int u,int v){//查询u和v的lca
if(dep[u]<dep[v])swap(u,v);
for(int i=;i<;i++){
if((dep[u]-dep[v])>>i&){
u=rt[i][u];
}
}
if(u==v)return u;
for(int i=;i>=;i--){
if(rt[i][u]!=rt[i][v]){
u=rt[i][u];
v=rt[i][v];
}
}
return rt[][u];
}
int jump(int &pos,int num,int tmp[][MAXN]){//查询节点pos的第num个父亲
int rel=-inf;
for(int i=;i<;i++){
if(num>>i&){
rel=max(rel,tmp[i][pos]);
pos=rt[i][pos];
}
}
return rel;
}
void solve(int a,int b){
int r=LCA(a,b);
if(dep[a]<dep[b])swap(a,b);
int maxa,maxb,maxc,maxd,half,v,w,ar,br;
maxa=maxb=maxc=maxd=;
ar=dep[a]-dep[r];
br=dep[b]-dep[r];
v=a;w=b;
half=min((ar+br)/,ar-);
maxa=max(son_dis[a],jump(v,half,up)+dep[a]);
maxb=jump(v,ar-half-,down)-dep[r]+br;
maxc=-inf;
maxd=fa_dis[r]+min(ar,br);
if(r!=b){
maxb=max(maxb,son_dis[b]);
maxb=max(maxb,jump(w,br-,up)+dep[b]);
set<st>::const_iterator it=my_set[r].begin();
for(int i=;i<=min((int)my_set[r].size(),);i++){
if(it->id!=v&&it->id!=w){
maxc=it->num++min(dep[a],dep[b])-dep[r];
break;
}
it++;
}
}
else{
set<st>::const_iterator it=my_set[r].begin();
for(int i=;i<=min((int)my_set[r].size(),);i++){
if(it->id!=v){
maxc=it->num+;
break;
}
it++;
}
}
printf("%d\n",max(max(maxa,maxb),max(maxc,maxd)));
}
int main(){
int n;
scanf("%d",&n);
for(int i=;i<n;i++){
int a,b;
scanf("%d%d",&a,&b);
addedge(a,b);
addedge(b,a);
}
int m;
memset(rt,-,sizeof(rt));
dfs(,);
for(edge *it=adj[];it;it=it->next){
if(father[it->id]==){
dfs2(it->id);
}
}
prelca(n);
scanf("%d",&m);
for(int i=;i<=m;i++){
int a,b;
scanf("%d%d",&a,&b);
solve(a,b);
}
}