bzoj 3572世界树 虚树+dp

时间:2023-03-08 19:19:41

题目大意:

给一棵树,每次给出一些关键点,对于树上每个点,被离它最近的关键点(距离相同被标号最小的)控制

求每个关键点控制多少个点

分析:

虚树+dp

dp过程如下:

第一次dp,递归求出每个点子树中关键点到它距离最小值

第二次dp,用第一次的信息,从上往下转移,求出每个点到所有关键点中到它距离最小值

这里兼容性讨论一下,发现可以不用存次大值,因为若最小值来自要更新的子树,则子树中点到上面的点的距离一定不优

前两次dp求出了虚树中1,2类点被谁控制

第三次dp,对于每条边,找到断点,细节见代码

注意:

虚树中这样算会算漏很多原树中的点

如2-1-3树,根是1,关键点1,2,这样3会算漏

用tmp存1的sz,遍历虚树子树时把tmp减掉已经算过的,剩下的就属于1的控制

吐槽题目有毒,去行末空格就PE,%……&%……&@¥%&

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cctype>
using namespace std;
typedef long long LL;
const int M=300007;
const int INF=1e9;
inline int rd(){
    int x=0;bool f=1;char c=getchar();
    for(;!isdigit(c);c=getchar())if(c=='-')f=0;
    for(;isdigit(c);c=getchar())x=x*10+c-48;
    return x;
}

int n,m;

int g[M],hd[M],te,td;
struct edge{int y,next;}e[M<<1],dw[M];

int top[M],dfn[M],pid[M],tdfn;
int pre[M],dep[M],sz[M],son[M];

struct node{
    int x,id;
    node(int xx=0,int ii=0){x=xx;id=ii;}
}que[M];
bool operator <(node x,node y){
    if(x.x<y.x) return 1;
    if(x.x==y.x&&x.id<y.id) return 1;
    return 0;
}
int st[M],tot;

int ans[M];
int bl[M];
int bid[M];
int pbid[M];
node mn[M];//不求次大值,兼容性的问题

node gmn(node x,node y){
    return x<y?x:y;
}

bool cmp(node x,node y){return dfn[x.x]<dfn[y.x];}

void addedge(int x,int y){
    e[++te].y=y;e[te].next=g[x];g[x]=te;
}
void addlink(int x,int y){
    if(x==y) return;
    dw[++td].y=y;dw[td].next=hd[x];hd[x]=td;
}

void dfs1(int x){
    sz[x]=1;
    int p,y;
    for(p=g[x];p;p=e[p].next)
    if((y=e[p].y)!=pre[x]){
        pre[y]=x;
        dep[y]=dep[x]+1;
        dfs1(y);
        sz[x]+=sz[y];
        if(sz[y]>sz[son[x]]) son[x]=y;
    }
}

void dfs2(int x){
    pid[dfn[x]=++tdfn]=x;
    int p,y;
    if(son[x]){
        top[son[x]]=top[x];
        dfs2(son[x]);
    }
    for(p=g[x];p;p=e[p].next)
    if((y=e[p].y)!=pre[x]&&y!=son[x]){
        top[y]=y;
        dfs2(y);
    }
}

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=pre[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}

int jumpc(int x,int y){
    while(dep[top[x]]>dep[y]){
        x=top[x];
        if(pre[x]==y) return x;
        x=pre[x];
    }
    return pid[dfn[y]+1];
}

int stp(int x,int y){
    int tp;
    while( y && (tp=dep[x]-dep[pre[top[x]]])<=y){//y&&
        y-=tp;
        x=pre[top[x]];
    }
    return pid[dfn[x]-y];
}

void vbuild(int z){
    sort(que+1,que+z+1,cmp);
    int i,x,anc;
    for(i=1;i<z;i++){
        x=LCA(que[i].x,que[i+1].x);
        hd[x]=0;
        bl[x]=0;
    }
    hd[1]=0;
    bl[1]=0;
    for(i=1;i<=z;i++){
        x=que[i].x;
        hd[x]=0;
        bl[x]=que[i].id;
        ans[que[i].id]=0;
    }

    td=tot=0;
    st[++tot]=1;
    for(i=1;i<=z;i++){
        x=que[i].x;
        anc=LCA(x,st[tot]);
        if(anc==st[tot]) st[++tot]=x;
        else{
            while(tot>1 && dep[anc]<=dep[st[tot-1]]){
                addlink(st[tot-1],st[tot]);
                tot--;
            }
            addlink(anc,st[tot]);
            st[tot]=anc;
            st[++tot]=x;
        }
    }
    for(i=1;i<tot;i++) addlink(st[i],st[i+1]);
}

void dp1(int x){
    if(!bl[x])mn[x]=node(INF,0);
    else mn[x]=node(0,pbid[bl[x]]);
    int p,y;
    node tp;
    for(p=hd[x];p;p=dw[p].next){
        y=dw[p].y;
        dp1(y);
        tp=mn[y];
        tp.x+=dep[y]-dep[x];
        mn[x]=min(mn[x],tp);
    }
}

void dp2(int x){
    if(!bl[x]) bl[x]=bid[mn[x].id];
    int p,y;
    node tp;
    for(p=hd[x];p;p=dw[p].next){
        y=dw[p].y;
        tp=mn[x];
        tp.x+=dep[y]-dep[x];
        mn[y]=min(mn[y],tp);
        dp2(y);
    }
}

void dp3(int x){
    int p,y,z,d,tp;
    int ss=sz[x];
    for(p=hd[x];p;p=dw[p].next){
        y=dw[p].y;
        d=mn[x].x+mn[y].x+dep[y]-dep[x]-1;//加上mn,最后要减1,算出要真正要竞争领地的两点间点数,点数!
        z=jumpc(y,x);
        ss-=sz[z];
        dp3(y);
        if(bl[x]==bl[y]) ans[bl[x]]+=sz[z]-sz[y];
        else{
            if(d%2==0){
                tp=stp(y,d/2-mn[y].x);
            }
            else{
                if(pbid[bl[y]]<pbid[bl[x]]) tp=stp(y,d/2+1-mn[y].x);//
                else tp=stp(y,d/2-mn[y].x);
            }
            ans[bl[x]]+=sz[z]-sz[tp];
            ans[bl[y]]+=sz[tp]-sz[y];
        }
    }
    ans[bl[x]]+=ss;//算漏的补上
}

int main(){
    freopen("a.txt","r",stdin);
    int i,x,y,z;
    n=rd();
    for(i=1;i<n;i++){
        x=rd(),y=rd();
        addedge(x,y);
        addedge(y,x);
    }

    dep[1]=pre[1]=0;
    dfs1(1);
    top[1]=1;
    dfs2(1);

    m=rd();
    while(m--){
        z=rd();
        for(i=1;i<=z;i++){
            que[i].x=rd();
            bid[que[i].x]=i;
            que[i].id=i;
            pbid[i]=que[i].x;
        }
        vbuild(z);
        dp1(1);
        dp2(1);
        dp3(1);
        for(i=1;i<=z;i++) printf("%d ",ans[i]);
        puts("");
    }
    return 0;
}