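// 贴下以前写的代码 — posting some code I wrote earlier.
// 比赛前我准备着重看的 — topics I plan to focus on before the contest:
// 主席树 树dp 字符串 — persistent segment tree ("President Tree"), tree DP, strings.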
#include<bits/stdc++.h>
using namespace std;
using ll = long long;             // 64-bit type for accumulated answers
const int MAXN = 100005;          // array bound: node indices 1..N and value positions 1..100000
const int MOD = 1000000007;       // not referenced by the code below
int N;                            // number of nodes in the current test case
int A[MAXN];                      // A[i]: value of node i (assumed to lie in [1, 100000])
std::vector<int> mp[MAXN];        // mp[p]: children of node p (filled from parent indices in main)
int val[MAXN][2];                 // val[x][0]: lower median of x's subtree; val[x][1]: next order statistic (or 100000)
/********** President Tree (主席树) — used here as a mergeable segment tree over values [1, 100000] **********/
// Node pool. main resets tot to 1, so Build hands out indices 1, 2, ...;
// index 0 is never allocated and acts as the shared empty node (all fields 0).
int tot;

struct Node {
    int ls;  // index of the left child in tree[], 0 when absent
    int rs;  // index of the right child in tree[], 0 when absent
    int cc;  // how many stored values fall in this node's value range
    Node(int a = 0, int b = 0, int c = 0) : ls(a), rs(b), cc(c) {}
};

// With the value range [1, 100000] used by dfs, each Build allocates one
// root-to-leaf path of 18 nodes and Merge allocates none, so 20 slots per
// node is enough for N < MAXN.
Node tree[MAXN * 20];

int sta[MAXN];  // sta[x]: root of the value tree holding every value in x's subtree
int Build(int pos,int l,int r){
int rt = tot++;
tree[rt].cc = 1;
tree[rt].ls = tree[rt].rs = 0;
if(l == r) return rt;
int mid = (l+r) >> 1;
if(pos <= mid) tree[rt].ls = Build(pos, l,mid);
else tree[rt].rs = Build(pos, mid+1,r);
return rt;
}
// Merge tree y into tree x over the value range [l, r] and return the merged root.
// The merge is destructive: x's nodes are updated in place, and subtrees of y are
// linked into the result wherever x has no node. Later merges may therefore modify
// them, so y should not be queried on its own afterwards.
int Merge(int x, int y, int l, int r) {
    if (x == 0) return y;   // nothing on x's side: adopt y's subtree
    if (y == 0) return x;   // nothing on y's side: keep x unchanged
    tree[x].cc += tree[y].cc;
    if (l != r) {
        const int mid = (l + r) >> 1;
        tree[x].ls = Merge(tree[x].ls, tree[y].ls, l, mid);
        tree[x].rs = Merge(tree[x].rs, tree[y].rs, mid + 1, r);
    } else {
        // A leaf keeps no children.
        tree[x].ls = 0;
        tree[x].rs = 0;
    }
    return x;
}
// Return the K-th smallest value (counting duplicates) stored in the tree rooted
// at rt over [l, r]. Returns 100000 when fewer than K values are stored. Note that
// this sentinel is indistinguishable from a stored value of 100000.
int Query(int rt, int K, int l, int r) {
    int node = rt;
    for (;;) {
        if (tree[node].cc < K) return 100000;
        if (l == r) return l;
        const int mid = (l + r) >> 1;
        const int leftCount = tree[tree[node].ls].cc;
        if (leftCount >= K) {
            node = tree[node].ls;
            r = mid;
        } else {
            K -= leftCount;          // skip the values held by the left half
            node = tree[node].rs;
            l = mid + 1;
        }
    }
}
// Post-order pass over the subtree of x. It builds a value tree holding every
// value in the subtree (x's own value merged with each child's tree) and stores
// its root in sta[x]. It then records two order statistics, with k = ceil(size / 2):
//   val[x][0] = k-th smallest value (lower median),
//   val[x][1] = (k+1)-th smallest value, or 100000 when the subtree has one node.
void dfs(int x) {
    int root = Build(A[x], 1, 100000);
    for (const int child : mp[x]) {
        dfs(child);
        root = Merge(root, sta[child], 1, 100000);  // child's tree is consumed here
    }
    sta[x] = root;
    const int k = (tree[root].cc + 1) >> 1;
    val[x][0] = Query(root, k, 1, 100000);
    val[x][1] = Query(root, k + 1, 1, 100000);
}
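/************** BIT (Fenwick tree) over positions 1..100000: Add walks indices down and Sum walks them up, so Sum(p) totals the Adds at positions >= p *************/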
// Fenwick array indexed by value position (1..100000 < MAXN). Add/Sum below
// maintain it so that Sum(p) is the total added at positions >= p.
ll bitt[MAXN];
// Point update for the suffix-sum Fenwick tree: add `num` at position `pos`.
// The walk goes downward by clearing the lowest set bit each step, which is the
// mirror of Sum's upward walk.
void Add(int pos, int num) {
    while (pos > 0) {
        bitt[pos] += num;
        pos &= pos - 1;  // equivalent to pos -= pos & -pos
    }
}
// Suffix query for the Fenwick tree maintained by Add: returns the total of all
// Add(p, num) calls with pos <= p <= 100000. Add walks indices downward, so
// walking upward here meets each such update exactly once.
// A position below 1 means "everything", so it is clamped to 1. Without the clamp,
// i & -i is 0 once i reaches 0, and the loop never terminates for pos <= 0.
ll Sum(int pos) {
    if (pos < 1) pos = 1;
    ll ans = 0;
    for (int i = pos; i <= 100000; i += i & -i)
        ans += bitt[i];
    return ans;
}
// DFS that keeps, in the Fenwick tree, the term val[y][1] - val[y][0] at position
// val[y][0] for every node y on the path from the initial call down to the current
// node. For x and each of its descendants z, it evaluates Sum(val[z][0]): the total
// of those terms over path nodes y with val[y][0] >= val[z][0]. It returns the
// largest such total. Each node's term is removed again before returning, so the
// BIT is left as it was found.
ll dfs2(int x) {
    const int pos = val[x][0];
    const int delta = val[x][1] - val[x][0];
    Add(pos, delta);
    ll best = Sum(pos);
    for (const int child : mp[x])
        best = max(best, dfs2(child));
    Add(pos, -delta);
    return best;
}
// Reads test cases until EOF. Each case is: N, then A[1..N], then the parent of
// each node 2..N (node 1 is the root). For each case it prints the sum of
// val[i][0] over all nodes plus dfs2(1).
// The loop tests `scanf(...) == 1` rather than `~scanf(...)`: on malformed input
// scanf returns 0, ~0 is non-zero, and the old loop would spin forever. Truncated
// input inside a case also stops the program instead of reusing stale values.
int main() {
    while (scanf("%d", &N) == 1) {
        memset(bitt, 0, sizeof(bitt));
        tot = 1;  // tree[0] stays the shared empty node; Build allocates from 1
        for (int i = 1; i <= N; ++i)
            if (scanf("%d", &A[i]) != 1) return 0;
        for (int i = 1; i <= N; ++i) mp[i].clear();
        for (int i = 2; i <= N; ++i) {
            int a;
            if (scanf("%d", &a) != 1) return 0;
            mp[a].push_back(i);  // a is the parent of node i
        }
        dfs(1);
        ll sum = 0;
        for (int i = 1; i <= N; ++i)
            sum += val[i][0];
        printf("%lld\n", sum + dfs2(1));
    }
    return 0;
}