思路
虚树上DP
虚树相当于一颗包含了所有询问的关键点信息的树,包含的所有点只有询问点和它们的LCA,所以点数是\(2k\)级别的,这样的话复杂度就是\(O(\sum k)\),复杂度就对了
虚树重点就是虚树的构造
用栈可以进行虚树的构造
过程如下
设现在加入点u
如果栈为空或只有一个元素,直接加入即可(延长当前链)
如果LCA(u,S[top])=S[top],把u加入即可(延长树链)
否则证明u和S中的树链在lca的两个子树中,在dfn[lca]<=dfn[S[top-1]]的条件下,从S[top-1]向S[top]连边,然后弹出
如果最后lca=S[top],证明这颗子树构造完成,加入u即可
否则证明lca在S[top-1]和S[top]之间,从lca向S[top]连边,然后pop出S[top],lca入栈
最后把u加入即可
这题建出虚树之后就直接DP就好了
如果u不是关键点
\(DP[u]=\sum_{v\in son[u]} min(minx[v],DP[v])\)
如果u是关键点
\(DP[u]=minx[u]\)
minx[u]是断开1到u路径的最小代价
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <stack>
#include <vector>
#define int long long
using namespace std;
const int MAXN = 250010;
int n,m;
struct Graph{
vector<int> to[MAXN],wx[MAXN];
void addedge(int ui,int vi,int wi){
to[ui].push_back(vi);
wx[ui].push_back(wi);
}
}G1,G2;
int S[MAXN],topx,dfn[MAXN],dfs_clock,fa[MAXN][20],dep[MAXN],minx[MAXN],mark[MAXN];
void dfs1(int u,int f){
dep[u]=dep[f]+1;
dfn[u]=++dfs_clock;
fa[u][0]=f;
for(int i=1;i<20;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=0;i<G1.to[u].size();i++){
int vi=G1.to[u][i];
if(vi==f)
continue;
minx[vi]=min(G1.wx[u][i],minx[u]);
dfs1(vi,u);
}
}
int lca(int x,int y){
if(dep[x]<dep[y])
swap(x,y);
for(int i=19;i>=0;i--)
if(dep[fa[x][i]]>=dep[y])
x=fa[x][i];
if(x==y)
return x;
for(int i=19;i>=0;i--)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
bool cmp(int a,int b){
return dfn[a]<dfn[b];
}
void insert(int u){
if(topx<=1){
S[++topx]=u;
return;
}
int Lca=lca(u,S[topx]);
if(Lca==S[topx]){
S[++topx]=u;
return;
}
while(topx>1&&dfn[Lca]<=dfn[S[topx-1]]){
G2.addedge(S[topx-1],S[topx],0);
topx--;
}
if(Lca!=S[topx]){
G2.addedge(Lca,S[topx],0);
S[topx]=Lca;
}
S[++topx]=u;
}
int dfs2(int u){
int ans=0;
for(int i=0;i<G2.to[u].size();i++)
ans+=min(minx[G2.to[u][i]],dfs2(G2.to[u][i]));
G2.to[u].clear();
if(mark[u]){
mark[u]=false;
return minx[u];
}
else
return ans;
}
vector<int> im;
signed main(){
scanf("%lld",&n);
for(int i=1;i<n;i++){
int a,b,c;
scanf("%lld %lld %lld",&a,&b,&c);
G1.addedge(a,b,c);
G1.addedge(b,a,c);
}
minx[1]=0x3f3f3f3f;
dfs1(1,0);
scanf("%lld",&m);
for(int i=1;i<=m;i++){
im.clear();
int x,k;
scanf("%lld",&k);
for(int j=1;j<=k;j++){
scanf("%lld",&x);
im.push_back(x);
mark[x]=true;
}
sort(im.begin(),im.end(),cmp);
insert(1);
for(int i=0;i<im.size();i++)
insert(im[i]);
while(topx>0){
G2.addedge(S[topx-1],S[topx],0);
topx--;
}
printf("%lld\n",dfs2(1));
}
return 0;
}