UOJ #11. 【UTR #1】ydc的大树

时间:2023-03-09 16:32:08
UOJ #11. 【UTR #1】ydc的大树

题目描述:

ydc有一棵n个结点的黑白相间的大树,从1到n编号。

这棵黑白树中有m个黑点,其它都是白点。

对于一个黑点我们定义他的好朋友为离他最远的黑点。如果有多个黑点离它最远那么都是它的好朋友。两点间的距离定义为两点之间的最短路的长度。

现在你要摧毁一个白点。

摧毁后有一些黑点会不高兴。一个黑点不高兴当且仅当他不能到达任何一个在摧毁那个白点前的好朋友。

请你最大化不高兴的黑点数。

解题报告:

套路题啊,直接提黑点重心到根,那么这样就可以保证每一个黑点到其最远的黑点一定经过根节点了,那么就可以开始枚举白点并讨论。

首先白点的子树内的黑点都是被截断的,然后考虑其是否在到根最长路和次长路上:

如果有三个及以上最长路显然是无法截断的。

如果仅有两个及以下,那么如果在最长路上,且最长路唯一,那么除了该子树内最长路所在子树外的所有黑点一定都能被截断。

如果是在次长路上,那么最长路上的黑点就要走到次长路上,一定会被截断

#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#define RG register
#define il inline
#define iter iterator
#define Max(a,b) ((a)>(b)?(a):(b))
#define Min(a,b) ((a)<(b)?(a):(b))
using namespace std;
const int N=1e5+5;
int num=0,head[N],nxt[N<<1],to[N<<1],dis[N<<1],n,m;bool col[N];
void link(int x,int y,int z){nxt[++num]=head[x];to[num]=y;dis[num]=z;head[x]=num;}
int root=0,f[N],g[N],bel[N],sz[N];
void getroot(int x,int last,int *p){
int u;if(col[x] && p[root]<p[x])root=x;
for(int i=head[x];i;i=nxt[i]){
u=to[i];if(u==last)continue;
p[u]=p[x]+dis[i];
getroot(u,x,p);
}
}
void dfs(int x,int last,int dist){
int u;
if(last==root)bel[x]=x;
else bel[x]=bel[last];
sz[x]=col[x];
for(int i=head[x];i;i=nxt[i]){
u=to[i];if(u==last)continue;
dfs(u,x,dist+dis[i]);
sz[x]+=sz[u];
}
if(sz[x]){
f[x]=dist;g[x]=1;
for(int i=head[x];i;i=nxt[i]){
u=to[i];if(u==last)continue;
if(f[u]>f[x]){f[x]=f[u];g[x]=g[u];}
else if(f[x]==f[u])g[x]+=g[u];
}
}
}
void work()
{
int x,y,z;
scanf("%d%d",&n,&m);
for(int i=1;i<=m;i++){scanf("%d",&x);col[x]=true;}
for(int i=1;i<n;i++){
scanf("%d%d%d",&x,&y,&z);
link(x,y,z);link(y,x,z);
}
root=0;getroot(1,1,f);
x=root;root=0;getroot(x,x,g);
memset(f,0,sizeof(f));
getroot(root,root,f);root=0;
for(int i=1;i<=n;i++)
if(g[i]+f[i]==f[x]){
if(!root)root=i;
else if(abs(f[i]-g[i])<abs(f[root]-g[root]))root=i;
}
for(RG int i=0;i<N;i++)f[i]=g[i]=0;
dfs(root,root,0);
int mx=0,cmx=0,mxid=0,cmxid=0,ans=0,answer=0,tot=0,mxcnt=0;
if(!col[root])ans=m,answer=1;
for(int i=head[root];i;i=nxt[i]){
int u=to[i];if(!sz[u])continue;
if(f[u]>mx){cmx=mx;cmxid=mxid;mx=f[u];mxid=u;mxcnt=0;}
else if(f[u]==mx){
if(f[mxid]==f[cmxid]){mxid=0;cmxid=0;mxcnt=3;}
else{cmxid=mxid;cmx=mx;mxid=u;mx=f[u];}
}
else if(f[u]>cmx){cmxid=u;cmx=f[u];mxcnt=0;}
}
if(mxcnt==3)mxid=0,cmxid=0;
for(int i=1;i<=n;i++){
if(col[i]==1 || i==root)continue;
tot=sz[i];
if(tot && f[i]==f[bel[i]] && g[i]==g[bel[i]]){
if(bel[i]==mxid){
if(f[mxid]!=f[cmxid])tot+=m-sz[mxid];
else tot+=sz[cmxid];
}
else if(bel[i]==cmxid)tot+=sz[mxid];
}
if(tot>ans){ans=tot;answer=1;}
else if(ans==tot)answer++;
}
printf("%d %d\n",ans,answer);
}
int main(){work();return 0;}