题意:给出一棵无根树,每个节点有一个权值,现在要让dfs序的前k个结点的最小值最大,求出这个值。
考虑二分答案,把>=答案的点标记为1,<答案的点标记为0,现在的任务时使得dfs序的前k个节点都为1.
考虑树形DP。
用dp[u]表示从节点u开始在子树中进行dfs最多可以经过多少个为1的结点,显然,若某一个子树中节点全为1,那么这个可以加到dp[u]中,此外还可以在不全为1的子树中挑选一个加到dp[u]上。
那么答案就是从标记为1的节点当做根,选两颗不完全子树和所有的完全子树(包括从父亲向上的部分)。
那么如果从父亲向上的部分是不完全子树呢,那等价于从这颗不完全子树上的一个深度最小的点做上面的计算一下。所以不需要考虑从父亲向上的部分是不完全子树这个情况。
时间复杂度O(nlogn).
# include <cstdio>
# include <cstring>
# include <cstdlib>
# include <iostream>
# include <vector>
# include <queue>
# include <stack>
# include <map>
# include <bitset>
# include <set>
# include <cmath>
# include <algorithm>
using namespace std;
# define lowbit(x) ((x)&(-x))
# define pi acos(-1.0)
# define eps 1e-
# define MOD
# define INF
# define mem(a,b) memset(a,b,sizeof(a))
# define FOR(i,a,n) for(int i=a; i<=n; ++i)
# define FO(i,a,n) for(int i=a; i<n; ++i)
# define bug puts("H");
# define lch p<<,l,mid
# define rch p<<|,mid+,r
# define mp make_pair
# define pb push_back
typedef pair<int,int> PII;
typedef vector<int> VI;
# pragma comment(linker, "/STACK:1024000000,1024000000")
typedef long long LL;
int Scan() {
int x=,f=;char ch=getchar();
while(ch<''||ch>''){if(ch=='-')f=-;ch=getchar();}
while(ch>=''&&ch<=''){x=x*+ch-'';ch=getchar();}
return x*f;
}
const int N=;
//Code begin... struct Edge{int p, next;}edge[N<<];
int node[N], head[N], cnt=, dp[N], date[N], siz[N], tag[N], sum, n, K, ans;
bool flag[N]; void add_edge(int u, int v){edge[cnt].p=v; edge[cnt].next=head[u]; head[u]=cnt++;}
void dfs1(int x, int fa, int val){
siz[x]=;
if (node[x]<val) tag[x]=, ++sum;
if (node[x]<val) flag[x]=true;
for (int i=head[x]; i; i=edge[i].next) {
int v=edge[i].p;
if (v==fa) continue;
dfs1(v,x,val); siz[x]+=siz[v]; tag[x]+=tag[v]; flag[x]|=flag[v];
}
}
void dfs2(int x, int fa, int val){
dp[x]=;
int f=, s=;
for (int i=head[x]; i; i=edge[i].next) {
int v=edge[i].p;
if (v==fa) continue;
dfs2(v,x,val);
if (!flag[v]) dp[x]+=siz[v];
else if (node[v]>=val) {
if (dp[v]>f) s=f, f=dp[v];
else if (dp[v]>s) s=dp[v];
}
}
dp[x]+=f;
if (node[x]>=val) {
if (tag[x]==sum) ans=max(ans,dp[x]+s+n-siz[x]);
else ans=max(ans,dp[x]+s);
}
}
bool check(int x){
mem(siz,); mem(dp,); mem(flag,false); mem(tag,); sum=ans=;
dfs1(,,x); dfs2(,,x);
return ans>=K;
}
int main ()
{
int u, v;
scanf("%d%d",&n,&K);
FOR(i,,n) scanf("%d",node+i), date[i]=node[i];
FO(i,,n) scanf("%d%d",&u,&v), add_edge(u,v), add_edge(v,u);
sort(date+,date+n+);
int l=, r=n+, mid;
while (l<r) {
mid=(l+r)>>;
if (l==mid) break;
if (check(date[mid])) l=mid;
else r=mid;
}
printf("%d\n",date[l]);
return ;
}