很显然如果改变一个点,那么如果这个点的值小于等于某个父亲的中位数的值,那么这个父亲的值就要变成比中位数大的最小的值,那么每个点对应了两个值,我们求这两个值的时候只需要用dfs来作为序号扔到主席树里,然后查找就行。奇怪的是用第k小的log^2的方法居然比直接找log的方法还快。
下面给出复杂度为log的代码
#include<cstdio> #include<cstring> #include<algorithm> #include<iostream> #include<vector> #include<cmath> using namespace std; const int maxn=100005; struct pi{ int sum; int lson; int rson; }pp[maxn*23]; int root[maxn],tot; void build(int cnt,int l,int r){ pp[cnt].sum=0; if(l==r) return; pp[cnt].lson=tot+1; tot++; build(tot,l,(l+r)/2); pp[cnt].rson=tot+1; tot++; build(tot,(l+r)/2+1,r); } void merg(int qq,int cnt,int n,int p,int k){ int le,ri,mid; le=1; ri=n; while(le<=ri){ mid=(le+ri)/2; pp[cnt]=pp[qq]; pp[cnt].sum+=k; if (le==ri) break; if(p<=mid){ pp[cnt].lson=tot+1; tot++; ri=mid; cnt=tot; qq=pp[qq].lson; } else{ pp[cnt].rson=tot+1; tot++; le=mid+1; cnt=tot; qq=pp[qq].rson; } } } int find(int pre,int cnt,int l,int r,int k){ if(l==r) return l; int w=pp[pp[cnt].lson].sum-pp[pp[pre].lson].sum; if(w>=k){ return find(pp[pre].lson,pp[cnt].lson,l,(l+r)>>1,k); } return find(pp[pre].rson,pp[cnt].rson,(l+r)/2+1,r,k-w); } int a[maxn]; vector<int>g[maxn]; int b[maxn]; int cnt[maxn],r[maxn],no; int x[maxn],y[maxn]; int size[maxn]; int to[maxn]; void dfs(int u){ size[u]=1; cnt[u]=++no; to[no]=u; for(int v:g[u]){ dfs(v); size[u]+=size[v]; } r[u]=no; } long long ans; long long sum[maxn]; void dfs1(int u){ sum[u]=b[x[u]]; for(int v:g[u]){ dfs1(v); sum[u]+=sum[v]; } } long long bit[maxn]; int low(int p){ return p&(-p); } void merg(int p,int n,int k){ while (p<=n) { bit[p]+=k; p=p+low(p); } } long long query(int p){ long long s=0; while (p) { s+=bit[p]; p=p-low(p); } return s; } void dfs2(int u,long long all,int n){ merg(1,n,b[y[u]]); merg(x[u]+1,n,b[x[u]]-b[y[u]]); long long s=query(a[u]); s=all+s+sum[u]-b[x[u]]; ans=max(ans,s); for(int v:g[u]){ dfs2(v,all+sum[u]-b[x[u]]-sum[v],n); } merg(1,n,-b[y[u]]); merg(x[u]+1,n,-b[x[u]]+b[y[u]]); } int main() { int n; while (scanf("%d",&n)!=EOF) { for(int i=1;i<=n;i++){ scanf("%d",&a[i]); b[i]=a[i]; } for(int i=1;i<=n;i++) g[i].clear(); for(int i=2;i<=n;i++){ int p; scanf("%d",&p); g[p].push_back(i); } sort(b+1,b+1+n); for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+n,a[i])-b; b[100001]=100000; no=0; tot=0; dfs(1); build(0,1,n); for(int i=1;i<=n;i++){ bit[i]=0; int w=to[i]; root[i]=++tot; merg(root[i-1],root[i],n,a[w],1); } for(int i=1;i<=n;i++){ x[i]=find(root[cnt[i]-1],root[r[i]],1,n,(size[i]+1)/2); if(size[i]==1) y[i]=100001; else y[i]=find(root[cnt[i]-1],root[r[i]],1,n,(size[i]+1)/2+1); } ans=0; dfs1(1); ans=sum[1]; dfs2(1,0,n); printf("%lld\n",ans); } }