[DP优化方法之虚树]

时间:2023-03-09 00:14:33
[DP优化方法之虚树]

首先我们看一篇文章 转自xyz:

给出一棵树.

每次询问选择一些点,求一些东西.这些东西的特点是,许多未选择的点可以通过某种方式剔除而不影响最终结果.

于是就有了建虚树这个技巧.....

我们可以用log级别的时间求出点对间的lca....

那么,对于每个询问我们根据原树的信息重新建树,这棵树中要尽量少地包含未选择节点. 这棵树就叫做虚树.

接下来所说的"树"均指虚树,原来那棵树叫做"原树".

构建过程如下:

按照原树的dfs序号(记为dfn)递增顺序遍历选择的节点. 每次遍历节点都把这个节点插到树上.

首先虚树一定要有一个根. 随便扯一个不会成为询问点的点作根.

维护一个栈,它表示在我们已经(用之前的那些点)构建完毕的虚树上,以最后一个插入的点为端点的DFS链.

设最后插入的点为p(就是栈顶的点),当前遍历到的点为x.我们想把x插入到我们已经构建的树上去.

求出lca(p,x),记为lca.有两种情况:

  1.p和x分立在lca的两棵子树下.

  2.lca是p.

  (为什么lca不能是x?

   因为如果lca是x,说明dfn(lca)=dfn(x)<dfn(a),而我们是按照dfs序号遍历的,于是dfn(a)<dfn(x),矛盾.)

对于第二种情况,直接在栈中插入节点x即可,不要连接任何边(后面会说为什么).

对于第一种情况,要仔细分析.

我们是按照dfs序号遍历的(因为很重要所以多说几遍......),有dfn(x)>dfn(p)>dfn(lca).

这说明什么呢? 说明一件很重要的事:我们已经把lca所引领的子树中,p所在的子树全部遍历完了!

  简略的证明:如果没有遍历完,那么肯定有一个未加入的点h,满足dfn(h)<dfn(x),

        我们按照dfs序号递增顺序遍历的话,应该把h加进来了才能考虑x.

这样,我们就直接构建lca引领的,p所在的那个子树. 我们在退栈的时候构建子树.

p所在的子树如果还有其它部分,它一定在之前就构建好了(所有退栈的点都已经被正确地连入树中了),就剩那条链.

如何正确地把p到lca那部分连进去呢?

设栈顶的节点为p,栈顶第二个节点为q.

重复以下操作:

  如果dfn(q)>dfn(lca),可以直接连边q->p,然后退一次栈.

  如果dfn(q)=dfn(lca),说明q=lca,直接连边lca->p,此时子树已经构建完毕.

  如果dfn(q)<dfn(lca),说明lca被p与q夹在中间,此时连边lca->q,退一次栈,再把lca压入栈.此时子树构建完毕.

    如果不理解这样操作的缘由可以画画图.....

最后,为了维护dfs链,要把x压入栈. 整个过程就是这样.....

然后就是我自己的理解了 我觉得我的理解虽然不是很严谨但是很容易懂

其实说白了就是如果我找到一个点不在这条链上 然后我们就跳栈顶的点使得栈顶的点和第二栈顶的点夹着lca 当然有可能第二栈顶的点就是lca 每次跳的时候连边

然后弹掉栈顶的点 如果现在栈顶的点不是lca就把lca塞进去 不是now的点就把now塞进去(这个应该是怕同此询问有重复的点吧 我去掉也ac)

然后的话虚树解决的就是总询问点数很少 询问次数很多的题 然后后面的记得清空就好

top=0; S[++top]=1; Plen=0; P[++Plen]=1;
for(LL i=1;i<=K;i++)
{
LL now=H[i]; LL f=lca(S[top],now);
while(dfn[S[top-1]]>dfn[f]){ins(1,S[top-1],S[top],0); top--;}
if(dfn[S[top]]>dfn[f]){ins(1,f,S[top],0); top--;}
if(S[top]!=f) S[++top]=f,P[++Plen]=f;
S[++top]=now,P[++Plen]=now;
}
while(top>1){ins(1,S[top-1],S[top],0); top--;}

h数组是询问的点要按dfn序排一下

[Sdoi2011消耗战

这是一道模版题 找到所有点建完虚树后 然后dp 要删去一些边且费用最小 想一想 真正有用的也就只是这些点还有lca的点 所以的话dp一下 要不选下面点一直到根的最小值的和 要不就选lca到根最小值

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<cmath>
#include<vector>
#include<climits>
#define Maxn 250010
using namespace std;
typedef long long LL;
struct node{LL x,y,next,d;}edge[][Maxn*]; LL len[],first[][Maxn];
void ins(LL k,LL x,LL y,LL d){len[k]++; edge[k][len[k]].x=x; edge[k][len[k]].y=y; edge[k][len[k]].d=d; edge[k][len[k]].next=first[k][x]; first[k][x]=len[k];}
LL dep[Maxn],fa[Maxn][],minx[Maxn]; LL dfn[Maxn],id=; LL N,M;
void Dfs(LL x,LL f)
{
dfn[x]=++id;
for(LL k=first[][x];k!=-;k=edge[][k].next)
{
LL y=edge[][k].y;
if(y!=f){dep[y]=dep[x]+; fa[y][]=x; minx[y]=min(minx[x],edge[][k].d); Dfs(y,x);}
}
}
LL H[Maxn],K; LL top,S[Maxn];
bool Cmp(const LL &x,const LL &y){return dfn[x]<dfn[y];}
LL lca(LL x,LL y)
{
if(dep[x]<dep[y]) swap(x,y);
LL deep=dep[x]-dep[y];
for(LL i=;i>=;i--) if(deep>=(<<i)){deep-=(<<i); x=fa[x][i];}
if(x==y) return x;
for(LL i=;i>=;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][];
}
LL F[Maxn]; bool C[Maxn];
void DP(LL x)
{
F[x]=minx[x]; if(C[x]) return ; LL tmp=;
for(LL k=first[][x];k!=-;k=edge[][k].next)
{
LL y=edge[][k].y;
DP(y); tmp+=F[y];
}
if(tmp<F[x]) F[x]=tmp;
} LL P[Maxn],Plen;
void Solve()
{
for(LL i=;i<=Plen;i++) first[][P[i]]=-; len[]=;
scanf("%lld",&K); for(LL i=;i<=K;i++) scanf("%lld",&H[i]),C[H[i]]=;
sort(H+,H+K+,Cmp);
top=; S[++top]=; Plen=; P[++Plen]=;
for(LL i=;i<=K;i++)
{
LL now=H[i]; LL f=lca(S[top],now);
while(dfn[S[top-]]>dfn[f]){ins(,S[top-],S[top],); top--;}
if(dfn[S[top]]>dfn[f]){ins(,f,S[top],); top--;}
if(S[top]!=f) S[++top]=f,P[++Plen]=f;
if(S[top]!=now) S[++top]=now,P[++Plen]=now;
}
while(top>){ins(,S[top-],S[top],); top--;}
DP(); printf("%lld\n",F[]);
for(LL i=;i<=K;i++) C[H[i]]=;
}
int main()
{
scanf("%lld",&N); len[]=; memset(first[],-,sizeof(first[]));
for(LL i=;i<N;i++){LL x,y,d; scanf("%lld%lld%lld",&x,&y,&d); ins(,x,y,d); ins(,y,x,d);}
dep[]=; for(LL i=;i<=N;i++) minx[i]=LLONG_MAX; Dfs(,);
for(LL j=;j<=;j++)
for(LL i=;i<=N;i++)
fa[i][j]=fa[fa[i][j-]][j-];
scanf("%lld",&M); len[]=; memset(first[],-,sizeof(first[])); memset(C,,sizeof(C));
for(LL i=;i<=M;i++)
Solve();
return ;
}
[Hnoi2014]世界树

这一道题就比较劲了

一些点管辖整个树 这些点是给定的 首先我们建一颗虚树 然后因为虚树上有一些点是lca的 也就是空的 我们要把这些点dp一下看看最近去到哪里

然后的话对于虚树上每两个相临的节点 我们二分这两个节点的链 也就是原树上的链 然后找到中间点 切开之后分别属于那两边

但是我们忽略了一个地方 就是有一些点没有被找过 他们为那些询问点下面的 而且下面没有询问点了 那怎么办呢 这些点肯定是跟着上面祖先选什么我就选什么的 这样的话就在祖先那里统计一下扫过了多少个点 剩下没被扫过的就是祖先的

说起来容易打起来难,这道算是经典题 不做不算是会虚树

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<cmath>
#include<climits>
#define Maxn 300010
using namespace std;
const int inf=1e9;
struct node
{
int x,y,next,d;
}edge[][Maxn*]; int len[],first[][Maxn];
void ins(int k,int x,int y,int d){len[k]++; edge[k][len[k]].x=x; edge[k][len[k]].y=y; edge[k][len[k]].next=first[k][x]; first[k][x]=len[k];}
int N,Q; int deep[Maxn],size[Maxn],fa[][Maxn]; int dfn[Maxn],id=; int unc[Maxn];
bool Cmp(const int &x,const int &y){return dfn[x]<dfn[y];}
void Dfs(int x,int f)
{
size[x]=; dfn[x]=++id;
for(int k=first[][x];k!=-;k=edge[][k].next)
{
int y=edge[][k].y;
if(y!=f)
{
deep[y]=deep[x]+;
fa[][y]=x;
Dfs(y,x);
size[x]+=size[y];
}
}
}
int lca(int x,int y)
{
if(deep[x]<deep[y]) swap(x,y);
int d=(deep[x]-deep[y]);
for(int i=;i>=;i--) if((<<i)<=d) d-=(<<i),x=fa[i][x];
if(x==y) return x;
for(int i=;i>=;i--) if(fa[i][x]!=fa[i][y]) x=fa[i][x],y=fa[i][y];
return fa[][x];
}
int H[Maxn]; int P[Maxn],S[Maxn],top,plen; bool C[Maxn]; int dis(int x,int y){return deep[x]+deep[y]-*deep[lca(x,y)];} pair<int,int> F1[Maxn],F2[Maxn],G[Maxn];
void Dfs1(int x)
{
if(C[x]) F1[x]=F2[x]=make_pair(,x);
for(int k=first[][x];k!=-;k=edge[][k].next)
{
int y=edge[][k].y; Dfs1(y);
if(!C[x])
{
int D=dis(x,y);
if((F1[x].first>F1[y].first+D)||(F1[x].first==F1[y].first+D&&F1[x].second>F1[y].second)) F2[x]=F1[x],F1[x]=make_pair(F1[y].first+D,F1[y].second);
else if((F2[x].first>F1[y].first+D)||(F2[x].first==F1[y].first+D&&F2[x].second>F1[y].second)) F2[x]=make_pair(F1[y].first+D,F1[y].second);
}
}
}
int F[Maxn];
void Dfs2(int x,int f)
{
if(!C[x])
{
G[x].first=G[f].first+dis(x,f); G[x].second=G[f].second;
if(F1[f].second==F1[x].second)
{
if((F2[f].first+dis(f,x)<G[x].first)||(F2[f].first+dis(f,x)==G[x].first&&F2[f].second<G[x].second))
G[x].second=F2[f].second,G[x].first=F2[f].first+dis(f,x);
}
else
if((F1[f].first+dis(f,x)<G[x].first)||(F1[f].first+dis(f,x)==G[x].first&&F1[f].second<G[x].second))
G[x].second=F1[f].second,G[x].first=F1[f].first+dis(f,x);
}
else G[x]=make_pair(,x); if(C[x]) F[x]=x;
else
{
if(G[x].first<F1[x].first||(G[x].first==F1[x].first&&G[x].second<F1[x].second)) F[x]=G[x].second;
if(G[x].first>F1[x].first||(G[x].first==F1[x].first&&G[x].second>F1[x].second)) F[x]=F1[x].second;
}
for(int k=first[][x];k!=-;k=edge[][k].next)
{
int y=edge[][k].y;
Dfs2(y,x);
}
} int Find(int x,int D){for(int i=;i>=;i--) if(D>=(<<i)) D-=(<<i),x=fa[i][x]; return x;} int ans[Maxn];
void DP(int x)
{
ans[F[x]]++; unc[x]--;
for(int k=first[][x];k!=-;k=edge[][k].next)
{
int y=edge[][k].y; int L=Find(y,deep[y]-deep[x]-); int R=fa[][y]; int sizex=size[L]; unc[x]-=sizex; int ret=x;
if(F[x]!=F[y])
{
if(deep[L]<=deep[R])
{
while(deep[L]<=deep[R])
{
int mid=(deep[L]+deep[R])>>; int midx=Find(y,deep[y]-mid);
int disx=dis(F[x],midx); int disy=dis(F[y],midx);
if(disx>disy||(disx==disy&&F[x]>F[y])) R=Find(y,deep[y]-(mid-));
else if(disx<disy||(disx==disy&&F[x]<F[y])) L=Find(y,deep[y]-(mid+)),ret=midx;
}
ans[F[x]]+=size[Find(y,deep[y]-deep[x]-)]-size[Find(y,deep[y]-deep[ret]-)];
ans[F[y]]+=size[Find(y,deep[y]-deep[ret]-)]-size[y];
}
}
else ans[F[x]]+=size[Find(y,deep[y]-deep[x]-)]-size[y];
}
for(int k=first[][x];k!=-;k=edge[][k].next)
{
int y=edge[][k].y;
DP(y);
}
} int B[Maxn];
void Solve()
{
int K; scanf("%d",&K); for(int i=;i<=K;i++){scanf("%d",&H[i]); B[i]=H[i]; C[H[i]]=;}
sort(H+,H+K+,Cmp); top=; S[top]=; plen=; P[]=;
for(int i=;i<=K;i++)
{
int now=H[i]; int f=lca(now,S[top]);
while(dfn[S[top-]]>dfn[f]) ins(,S[top-],S[top],),top--;
if(dfn[S[top]]>dfn[f]) ins(,f,S[top],),top--;
if(S[top]!=f) S[++top]=f,P[++plen]=f;
if(S[top]!=now) S[++top]=now,P[++plen]=now;
}
while(top>) ins(,S[top-],S[top],),top--;
Dfs1();
Dfs2(,);
for(int i=;i<=plen;i++) unc[P[i]]=size[P[i]];
DP();
for(int i=;i<=plen;i++) ans[F[P[i]]]+=unc[P[i]];
for(int i=;i<=K;i++) printf("%d ",ans[B[i]]); printf("\n");
for(int i=;i<=K;i++) ans[H[i]]=;
for(int i=;i<=K;i++) C[H[i]]=;
for(int i=;i<=plen;i++) F1[P[i]].first=F1[P[i]].second=F2[P[i]].first=F2[P[i]].second=G[P[i]].first=G[P[i]].second=F[P[i]]=inf,first[][P[i]]=-;
len[]=; }
int main()
{
scanf("%d",&N); len[]=; memset(first[],-,sizeof(first[]));
for(int i=;i<N;i++){int x,y; scanf("%d%d",&x,&y); ins(,x,y,); ins(,y,x,);}
Dfs(,); for(int i=;i<=;i++) for(int j=;j<=N;j++) fa[i][j]=fa[i-][fa[i-][j]];
memset(first[],-,sizeof(first[])); len[]=;
for(int i=;i<=N;i++) F1[i].first=F1[i].second=F2[i].first=F2[i].second=G[i].first=G[i].second=F[i]=inf;
for(int i=;i<=N;i++) ans[i]=; scanf("%d",&Q);
for(int i=;i<=Q;i++)
Solve();
return ;
}
/*
10
2 1
3 2
4 3
5 4
6 1
7 3
8 3
9 4
10 1
5
2
6 1
5
2 7 3 6 9
1
8
4
8 7 10 3
5
2 9 3 5 8
*/