hdu5788 Level Up (2016多校第五场1008) 主席树

时间:2021-11-12 00:18:27

很显然如果改变一个点,那么如果这个点的值小于等于某个父亲的中位数的值,那么这个父亲的值就要变成比中位数大的最小的值,那么每个点对应了两个值,我们求这两个值的时候只需要用dfs来作为序号扔到主席树里,然后查找就行。奇怪的是用第k小的log^2的方法居然比直接找log的方法还快。

下面给出复杂度为log的代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<cmath>
using namespace std;
const int maxn=100005;
struct pi{
    int sum;
    int lson;
    int rson;
}pp[maxn*23];
int root[maxn],tot;
void build(int cnt,int l,int r){
    pp[cnt].sum=0;
    if(l==r) return;
    pp[cnt].lson=tot+1;
    tot++;
    build(tot,l,(l+r)/2);
    pp[cnt].rson=tot+1;
    tot++;
    build(tot,(l+r)/2+1,r);
}
void merg(int qq,int cnt,int n,int p,int k){
    int le,ri,mid;
    le=1;
    ri=n;
    while(le<=ri){
        mid=(le+ri)/2;
        pp[cnt]=pp[qq];
        pp[cnt].sum+=k;
        if (le==ri) break;
        if(p<=mid){
            pp[cnt].lson=tot+1;
            tot++;
            ri=mid;
            cnt=tot;
            qq=pp[qq].lson;
        }
        else{
            pp[cnt].rson=tot+1;
            tot++;
            le=mid+1;
            cnt=tot;
            qq=pp[qq].rson;
        }
    }
}
int find(int pre,int cnt,int l,int r,int k){
    if(l==r) return l;
    int w=pp[pp[cnt].lson].sum-pp[pp[pre].lson].sum;
    if(w>=k){
        return find(pp[pre].lson,pp[cnt].lson,l,(l+r)>>1,k);
    }
    return find(pp[pre].rson,pp[cnt].rson,(l+r)/2+1,r,k-w);
}
int a[maxn];
vector<int>g[maxn];
int b[maxn];
int cnt[maxn],r[maxn],no;
int x[maxn],y[maxn];
int size[maxn];
int to[maxn];
void dfs(int u){
    size[u]=1;
    cnt[u]=++no;
    to[no]=u;
    for(int v:g[u]){
        dfs(v);
        size[u]+=size[v];
    }
    r[u]=no;
}
long long ans;
long long sum[maxn];
void dfs1(int u){
    sum[u]=b[x[u]];
    for(int v:g[u]){
        dfs1(v);
        sum[u]+=sum[v];
    }
}
long long bit[maxn];
int low(int p){
    return p&(-p);
}
void merg(int p,int n,int k){
    while (p<=n) {
        bit[p]+=k;
        p=p+low(p);
    }
}
long long query(int p){
    long long s=0;
    while (p) {
        s+=bit[p];
        p=p-low(p);
    }
    return s;
}
void dfs2(int u,long long all,int n){
    merg(1,n,b[y[u]]);
    merg(x[u]+1,n,b[x[u]]-b[y[u]]);
    long long s=query(a[u]);
    s=all+s+sum[u]-b[x[u]];
    ans=max(ans,s);
    for(int v:g[u]){
        dfs2(v,all+sum[u]-b[x[u]]-sum[v],n);
    }
    merg(1,n,-b[y[u]]);
    merg(x[u]+1,n,-b[x[u]]+b[y[u]]);
}
int main()
{
    int n;
    while (scanf("%d",&n)!=EOF) {
        for(int i=1;i<=n;i++){ scanf("%d",&a[i]);
            b[i]=a[i];
        }
        for(int i=1;i<=n;i++) g[i].clear();
        for(int i=2;i<=n;i++){
            int p;
            scanf("%d",&p);
            g[p].push_back(i);
        }
        sort(b+1,b+1+n);
        for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+n,a[i])-b;
        b[100001]=100000;
        no=0;
        tot=0;
        dfs(1);
        build(0,1,n);
        for(int i=1;i<=n;i++){
            bit[i]=0;
            int w=to[i];
            root[i]=++tot;
            merg(root[i-1],root[i],n,a[w],1);
        }
        for(int i=1;i<=n;i++){
            x[i]=find(root[cnt[i]-1],root[r[i]],1,n,(size[i]+1)/2);
            if(size[i]==1) y[i]=100001;
            else
            y[i]=find(root[cnt[i]-1],root[r[i]],1,n,(size[i]+1)/2+1);
        }
        ans=0;
        dfs1(1);
        ans=sum[1];
        dfs2(1,0,n);
        printf("%lld\n",ans);
    }
}