UOJ#374. 【ZJOI2018】历史 贪心,LCT

时间:2022-01-16 21:52:12

原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ374.html

题解

想出正解有点小激动。

不过因为傻逼错误调到自闭。不如贺题

首先我们考虑如何 $O(n)$ 求一个答案。

首先,计算两条路径的贡献时,由于两国连续交战数次只算一次,所以我们可以只看这两条路径的交的最深点。

也就是说,我们可以对于每一个点分开考虑,假装他的同一个子树的所有点颜色相同,不同子树的点颜色不同,它本身也当作一个子树看。

假设 x 是当前节点,y 是 x 的子树。

设 size[v] 表示 v 子树的所有节点的 a[v] 之和。

那么我们容易推出两个断论:

1. x 节点对答案的贡献最多不超过 size[x] - 1 。

2. 设 max(size[y]) 表示 x 的所有子树中 size 最大的子树的 size ,当 max(size[y]) - 1 >= size[x] - max(size[y]) 时,都有使 x 的贡献为 size[x] - 1 的方案;否则, x 节点对答案的贡献最大为 max(size[y]) - 1 - (size[x] - max(size[y])) = 2max(size[y]) - 1 - size[x]

所以贡献为

$$min(size[x] - 1, 2max(size[y]) - 1 - size[x])$$

设 val[x] = size[x] - 1 ,可以证明 $\sum_{y} val[y] \leq \sum_{y} (size[y] - 1) \leq size[x] - 1 = val[x]$

则这个式子会更加好看(把常数消掉了,然并卵):

$$min(val[x],2max(val[y])-val[x])$$

现在已经可以轻松拿到 30 分了。

考虑 100 分怎么做。

我们可以发现好像操作的时候所有的 max(val[y]) 的 y 的变化次数不多啊!

于是我们可以想到 LCT 维护这个东西。

这里的 LCT 不是传统的 LCT 。

如果 val[x] >= val[fa[x]] 那么我们将 x 作为 fa[x] 的重儿子。我们可以发现每一个节点只有一个重儿子:由于 $\sum_{y} val[y] \leq val[x]$ ,而且两个子树的特殊情况特殊考虑一下发现也是对的。

这样的话,可能会有节点没有重儿子。

但是,从任意一个节点到根走过的轻边条数是 $O(\log \sum a[i])$ 的,因为每走过一条轻边,子树权值和至少翻一倍。

然后你发现修改一个点的时候只要修改它到根路径上的所有点权(val[x]),而且对于重链,它对答案的贡献是不变的!

所以只要对 $O(\log\sum a[i])$ 个轻边处理就好了。

由于要链上修改点权,所以每一段重链都要预先下传标记。

总的来说,这样做要跳过 $O(\log \sum a[i])$ 段重链,每段重链 splay 需要花费 $O(\log n)$ 的时间复杂度,所以看上去复杂度是 $O(n\log^2 n)$ 的。80分很开心了吧!更开心的是如果交上去的话它能 AC 。

这是为什么呢?我们考虑势能分析,定义势函数为 $\sum_{ LCT 上所有节点 }\ \ \ \ \ \log (该节点在splay结构上的size + 它的虚子树的size)$ ,类似于 splay 复杂度的证明,可以证明这个东西是均摊 $O(\log \sum a[i] + \log n)$ 的。

这里不把证明写出来了。懒得写了。

最终时间复杂度为 $O(n\log(n\sum a[i]))$ 。

注意在写代码的时候要注意一些细节。对于节点本身的贡献我们可以把每一个点拆成两个点,第一个点先连原先所有子树,再新建第二个点,让他们连起来,并使第一个点是第二个点的父亲,第二个点的权值为 a[x] - 1 。这样可以减掉几个 if 。

注意链上修改的时候,不是直接给根打标记就完事了,因为这里的 LCT 比较奇怪,所以直接打标记会多给一段后缀重链带来修改,所以我们还要再在这个后缀重链上打个标记来抵消根上的标记。

代码

#pragma GCC optimize("Ofast","inline")
#include <bits/stdc++.h>
#define clr(x) memset(x,0,sizeof (x))
#define For(i,a,b) for (int i=a;i<=b;i++)
#define Fod(i,b,a) for (int i=b;i>=a;i--)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define _SEED_ ('C'+'L'+'Y'+'A'+'K'+'I'+'O'+'I')
#define outval(x) printf(#x" = %d\n",x)
#define outvec(x) printf("vec "#x" = ");for (auto _v : x)printf("%d ",_v);puts("")
#define outtag(x) puts("----------"#x"----------")
using namespace std;
typedef long long LL;
LL read(){
LL x=0,f=0;
char ch=getchar();
while (!isdigit(ch))
f|=ch=='-',ch=getchar();
while (isdigit(ch))
x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return f?-x:x;
}
const int N=400005*2;
int n,m;
LL a[N],s[N],v[N],f[N];
vector <int> e[N];
LL ans=0;
void dfs(int x,int pre){
f[x]=pre;
s[x]=a[x];
LL Mx=a[x]-1;
for (auto y : e[x])
if (y!=pre){
dfs(y,x);
s[x]+=s[y];
Mx=max(Mx,v[y]);
}
v[x]=s[x]-1;
ans+=min(v[x],(v[x]-Mx)*2);
}
int fa[N],son[N][2];
LL val[N],Add[N],Mxv[N];
void LCT_build(){
clr(son),clr(val),clr(Add);
For(i,1,n){
fa[i]=i+n,val[i]=a[i]-1;
fa[i+n]=f[i]?f[i]+n:0,val[i+n]=v[i];
}
For(i,1,n*2){
Mxv[i]=val[i];
if (fa[i]&&val[i]*2>=val[fa[i]])
son[fa[i]][1]=i;
}
}
#define ls son[x][0]
#define rs son[x][1]
int isroot(int x){
return son[fa[x]][0]!=x&&son[fa[x]][1]!=x;
}
int wson(int x){
return son[fa[x]][1]==x;
}
void pushup(int x){
Mxv[x]=max(val[x],max(Mxv[ls],Mxv[rs]));
}
void pushdown(int x){
if (Add[x]){
if (ls)
val[ls]+=Add[x],Add[ls]+=Add[x],Mxv[ls]+=Add[x];
if (rs)
val[rs]+=Add[x],Add[rs]+=Add[x],Mxv[rs]+=Add[x];
Add[x]=0;
}
}
void pushadd(int x){
if (!isroot(x))
pushadd(fa[x]);
pushdown(x);
}
void rotate(int x){
if (isroot(x))
return;
int y=fa[x],z=fa[y],L=wson(x),R=L^1;
if (!isroot(y))
son[z][wson(y)]=x;
fa[x]=z,fa[y]=x,fa[son[x][R]]=y;
son[y][L]=son[x][R],son[x][R]=y;
pushup(y),pushup(x);
}
void splay(int x){
pushadd(x);
for (int y=fa[x];!isroot(x);rotate(x),y=fa[x])
if (!isroot(y))
rotate(wson(x)==wson(y)?y:x);
}
void False_Access(int x){//pushdown the tags
while (x)
splay(x),x=fa[x];
}
void update(int x,LL w){
False_Access(x);
if (rs)
val[rs]-=w,Add[rs]-=w,Mxv[rs]-=w;
while (fa[x]){
int y=fa[x];
if (son[y][1]){
if (val[y]+w>Mxv[son[y][1]]*2){
ans+=val[y]+w-(val[y]-Mxv[son[y][1]])*2;
son[y][1]=0;
}
else
ans+=w*2;
}
else
ans+=w;
if ((Mxv[x]+w)*2>val[y]+w){
ans+=(val[y]+w-(Mxv[x]+w))*2-(val[y]+w);
son[y][1]=x;
}
else {
val[x]+=w,Add[x]+=w,Mxv[x]+=w;
if (son[y][1])
val[son[y][1]]-=w,Add[son[y][1]]-=w,Mxv[son[y][1]]-=w;
}
x=y;
}
val[x]+=w,Add[x]+=w,Mxv[x]+=w;
}
#undef ls
#undef rs
int main(){
n=read(),m=read();
For(i,1,n)
a[i]=read();
For(i,1,n-1){
int x=read(),y=read();
e[x].pb(y),e[y].pb(x);
}
dfs(1,0);
printf("%lld\n",ans);
LCT_build();
For(i,1,m){
int x=read(),w=read();
update(x,w);
printf("%lld\n",ans);
}
return 0;
}