题目链接:糖果公园
听说这是一道树上莫队的入门题,于是我就去写了……顺便复习了一下莫队的各种姿势。
首先,我们要在树上使用莫队,那么就需要像序列一样给树分块。这个分块的过程就是王室联邦这道题(vfleaking大神的博客里也有讲),当然也可以按照\(dfs\)序什么的进行分块。
但我还是想在这里讲一下我的分块做法(第一种)。我们可以先设一个块的大小$S$,然后对整棵树进行$dfs$。每次我们$dfs$完一棵子树回溯时,我们就把这个点$u$加入到一个队列中,更新$fa_u$的$siz$。这里$siz_u$记录的是以$u$为根的子树中为分好块的节点数。当我们发现$siz_u$超过了$S$时,我们就可以给这些节点新建一个块。最后还会剩下不超过$S$个节点,我们可以新建一个块,也可以把这些点丢到最后新建的那个块中。如果采用后一种方式,那么就保证了每个块的大小都在$[S,3S)$之间且最多只有一个块的大小会大于$2S$。于是我们就得到了一种优美的分块方法。(如果还不懂的话可以去看代码)
然后,这道题带了修改(其实并没有什么不同,只是需要修改块的大小以及多一个时间流逝的操作),于是给每个操作多一个时间戳即可。树上莫队和序列莫队最大的不同就在于如何转移。也就是我们如何由路径$(u_i,v_j)$走到路径$(u_j,v_j)$上去。其实就是先把$u_i$走到$u_j$,再把$v_i$走到$v_j$即可。注意,这里为了方便考虑,上述路径均没有把深度最小的点给算进去。这里没有讲清楚……还是往后看吧……
至于证明吗……我们先设几个东西(下列内容参考vfleaking大神的博客)……
设$S(u,v)$表示路径$(u,v)$上所有点的集合,$T(u,v)$表示$S(u,v)$去掉$lca(u,v)$之后剩下的点集。那么显然有:$$T(u,v)=S(root,u) \ xor \ S(root,v)$$
这里的$xor$指的是集合的对称差。简单来说就是把出现偶数次的节点给删掉。这个操作显然和异或差不多,性质也差不多。
于是我们可以考虑一下从$T(u_i,v_i)$转移到$T(u_j,v_j)$需要改变一些什么。$$\because T(u_i,v_i)=S(root,u_i)\ xor \ S(root,v_i)\quad T(u_j,v_j)=S(root,u_j)\ xor \ S(root,v_j)$$ $$\therefore T(u_i,v_i)\ xor \ T(u_j,v_j)=T(u_i,u_j)\ xor \ T(v_i,v_j)$$ $$\therefore T(u_j,v_j)=T(u_i,v_i)\ xor \ T(u_i,u_j)\ xor \ T(v_i,v_j)$$
于是我们从$T(u_i,v_i)$转移到$T(u_j,v_j)$只需要把点集$T(u_i,u_j)$和$T(v_i,v_j)$中的点存在性全部取反即可。
至于如何$O(1)$地进行插入和删除就不需要我多说了吧……
其实我感觉有vfleaking大神的博客应该不需要我的这一篇才对
下面贴代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define File(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
#define maxn 100010 using namespace std;
typedef long long llg; int n,m,q,V[maxn],co[maxn],num[maxn];
int fa[maxn][17],dep[maxn],siz[maxn];
int head[maxn],next[maxn<<1],to[maxn<<1],tt;
int qt[maxn],fr[maxn],qto,nblo,sizblo;
llg sw[maxn],now,ans[maxn]; bool vis[maxn];
struct data{
int u,v,t;
bool operator < (const data &h)const{
if(fr[u]!=fr[h.u]) return fr[u]<fr[h.u];
if(fr[v]!=fr[h.v]) return fr[v]<fr[h.v];
return t<h.t;
}
}s[maxn],cha[maxn<<1]; int getint(){
int w=0;bool q=0;
char c=getchar();
while((c>'9'||c<'0')&&c!='-') c=getchar();
if(c=='-') c=getchar(),q=1;
while(c>='0'&&c<='9') w=w*10+c-'0',c=getchar();
return q?-w:w;
} void link(int x,int y){
to[++tt]=y;next[tt]=head[x];head[x]=tt;
to[++tt]=x;next[tt]=head[y];head[y]=tt;
} void dfs(int u,int ff){
qt[++qto]=u;
fa[u][0]=ff; dep[u]=dep[ff]+1;
for(int i=1,now=2;now<dep[u];i++,now<<=1)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=head[u],v;v=to[i],i;i=next[i])
if(v!=ff){
dfs(v,u); siz[u]+=siz[v];
if(siz[u]>=sizblo){
nblo++; siz[u]=0;
while(qt[qto]!=u) fr[qt[qto--]]=nblo;
}
}
siz[u]++;
} void change(int u){
now-=V[co[u]]*sw[num[co[u]]];
if(vis[u]) num[co[u]]--,vis[u]=0;
else num[co[u]]++,vis[u]=1;
now+=V[co[u]]*sw[num[co[u]]];
} void timego(int x){
if(!cha[x].t) return;
bool ww=vis[cha[x].t];
if(ww) change(cha[x].t);
co[cha[x].t]=cha[x].v;
if(ww) change(cha[x].t);
} int lca(int u,int v){
if(dep[u]<dep[v]) swap(u,v); int t=0;
for(int now=1;now<dep[u];now<<=1) t++;t--;
for(int i=t;i>=0;i--)
if(dep[fa[u][i]]>=dep[v]) u=fa[u][i];
if(u==v) return u;
for(int i=t;i>=0;i--)
if(fa[u][i]!=fa[v][i])
u=fa[u][i],v=fa[v][i];
return fa[u][0];
} void query(int u,int v){
int g=lca(u,v);
while(u!=g) change(u),u=fa[u][0];
while(v!=g) change(v),v=fa[v][0];
} int main(){
File("a");
n=getint(); m=getint(); q=getint();
for(int i=1;i<=m;i++) V[i]=getint();
for(int i=1;i<=n;i++) sw[i]=sw[i-1]+getint();
for(int i=1;i<n;i++) link(getint(),getint());
sizblo=pow(n,2.0/3.0)+1; dfs(1,0); tt=0;
for(int i=1;i<=n;i++) if(!fr[i]) fr[i]=nblo;
for(int i=1;i<=n;i++) co[i]=num[i]=getint();
for(int i=1,ty;i<=q;i++){
ty=getint(); s[i].t=++tt;
s[i].u=getint(); s[i].v=getint();
if(!ty){
cha[i].u=num[s[i].u]; cha[i].t=s[i].u;
cha[i].v=num[s[i].u]=s[i].v;
cha[i+q]=cha[i]; swap(cha[i+q].u,cha[i+q].v);
}
else if(fr[s[i].u]>fr[s[i].v]) swap(s[i].u,s[i].v);
}
for(int i=1;i<=n;i++) num[i]=0;
sort(s+1,s+q+1);
for(int i=1,nt=0,u=1,v=1,g;i<=q;i++){
if(cha[s[i].t].t) continue;
g=lca(s[i].u,s[i].v);
while(nt<s[i].t) timego(++nt);
while(nt>s[i].t) timego(q+(nt--));
query(u,s[i].u); query(v,s[i].v);
change(g); ans[s[i].t]=now; change(g);
u=s[i].u; v=s[i].v;
}
for(int i=1;i<=q;i++)
if(!cha[i].t) printf("%lld\n",ans[i]);
return 0;
}