每个点的最优取值范围是一个区间,将叶子一层层剥去,得到一棵有根树,父亲的取值范围由儿子推得,时间复杂度$O(n\log n)$。
#include<cstdio>
#include<algorithm>
#define N 500010
int n,m,i,j,x,y,c,g[N],v[N<<1],nxt[N<<1],ed,d[N],l[N],r[N];
int del[N],G[N],V[N],NXT[N],h,t,q[N],f[N],a[N<<1];
long long ans,s,tmp,now;
inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';}
inline int abs(int x){return x>0?x:-x;}
inline void add(int x,int y){d[x]++;v[++ed]=y;nxt[ed]=g[x];g[x]=ed;}
inline void adde(int x,int y){f[y]=x;V[++ed]=y;NXT[ed]=G[x];G[x]=ed;}
void dfs(int x){
if(!G[x])return;
int i;
for(i=G[x];i;i=NXT[i])dfs(V[i]);
for(tmp=1LL<<60,j=c=s=m=0,i=G[x];i;i=NXT[i])a[m++]=l[V[i]],a[m++]=r[V[i]],c--,s+=l[V[i]];
for(std::sort(a,a+m),i=0;i<m;i++){
c++,s-=a[i],now=s+1LL*a[i]*c;
if(now<tmp)l[x]=a[i],tmp=now;
if(now==tmp)r[x]=a[i];
}
ans+=tmp;
}
int main(){
read(n),read(m);
for(i=1;i<n;i++)read(x),read(y),add(x,y),add(y,x);
for(i=1;i<=m;i++)read(l[i]),r[i]=l[i];
if(n==m){
for(i=1;i<=n;i++)for(j=g[i];j;j=nxt[j])ans+=abs(l[i]-l[v[j]]);
return printf("%lld",ans/2),0;
}
for(ed=0,i=h=1;i<=m;i++)del[q[++t]=i]=1;
for(;h<=t;h=x+1){
for(i=h;i<=t;i++)for(j=g[q[i]];j;j=nxt[j])if(!del[v[j]])adde(v[j],q[i]);
for(i=h,x=t;i<=x;i++)for(j=g[q[i]];j;j=nxt[j])if(!del[v[j]])if((--d[v[j]])<=1)del[q[++t]=v[j]]=1;
}
for(i=1;i<=n;i++)if(!f[i])adde(0,i);
dfs(0);
return printf("%lld",ans),0;
}