题目:codevs 1228 苹果树
链接:http://codevs.cn/problem/1228/
看了这么多树链剖分的解释,几个小时后总算把树链剖分弄懂了。
树链剖分的功能:快速修改,查询树上的路径。
比如一颗树
首先,我们要把树剖分成树链。定义:
fa[x]是x节点的上一层节点(就是他的爸爸)。
deep[x]是x节点的深度。
num[x]是x节点下面的子节点的数量(包括自己)
son[x]重儿子:一个节点的儿子的num[x]值最大的节点。其他的儿子都是轻儿子。
重链:重儿子连接在一起的路径,比如上图粗线就是重链(叶节点也是重链,只不过它只有一个点)。
重链之间是用一条轻链边连接的。
top[x]是每条重链的根节点,即是上图中的红色点。
tree[x]是数上节点在线段树上的编号
ftree[x]是线段树上节点在原来树的节点号
现在把它放到线段树里,从根节点开始编号为1,沿着重链走,每走到一个节点给它编号(可以用一个topa变量记录下一个编号),重链走完了走轻链。如图所示就给每条边都编上号了。如果边的长度没有,当然也可以把节点放在线段树上。图总的蓝色数字就是这条边在线段树里的位置,形成了区间,如下图。
然后把这个数组组成最终的线段树,就可以控制它的区间了。
可以发现,虽然看上去把树剖分放到线段树上好像打乱了树的顺序,线段树中的点仍然有原来树的影子。比如如果我要访问x节点的子树,那么这个节点的子树的区间就是从tree[x]到tree[x]+num[x]-1(-1是减掉自己这个节点)的区间。
我们可以用2个dfs来把剖分的动作实现。
第一个dfs先实现fa[x],deep[x],num[x]的计算,num要在访问完子树之后计算,见代码:
void dfs1(int x)
{
num[x]++;
for(int i=;i<map[x].size();i++)
{
int dd=map[x][i];
if(dd!=fa[x])
{
fa[dd]=x;
deep[dd]=deep[x]+;
dfs1(dd);
num[x]+=num[dd];
}
}
}
注释:map是STL的vector,用来储存边。
第二个dfs完成son[x],tree[x],ftree[x]的计算,代码如下:
void dfs2(int x)
{
topa++;
ftree[topa]=x;
A[topa]++;
tree[x]=topa;
int zi=,mx=;
for(int i=;i<map[x].size();i++)
{
int dd=map[x][i];
if(num[dd]>mx)
{
mx=num[dd];
zi=dd;
}
}
if(zi!=) dfs2(zi); else return;
son[x]=zi;
for(int i=;i<map[x].size();i++)
{
int dd=map[x][i];
if(dd!=zi) dfs2(dd);
}
}
剖分动作结束,接下来是线段树的事情了。
这里再说一下如何在线段树上操作原树,之前提到过,其实在线段树上也有原来树的结构。
x的子树区间就是tree[x]到tree[x]+num[x]-1。
下面来看一下这道题:codevs 1228 苹果树
这是一个最基本的树链剖分。题目中要求计算一颗子树上有苹果多少颗,改变是点修改。因此只要找到那个节点,子树在线段树上的位置,线段树是维护某区间的苹果树数量,查询操作就是一般的线段树查询。
代码:
#include<cstdio>
#include<vector>
#include<iostream>
using namespace std;
const int maxn=; vector<int> map[maxn];
int fa[maxn],n,deep[maxn],num[maxn],topa,A[maxn],tree[maxn],ftree[maxn],son[maxn],sumv[maxn*],k; void dfs1(int x)
{
num[x]++;
for(int i=;i<map[x].size();i++)
{
int dd=map[x][i];
if(dd!=fa[x])
{
fa[dd]=x;
deep[dd]=deep[x]+;
dfs1(dd);
num[x]+=num[dd];
}
}
} void dfs2(int x)
{
topa++;
ftree[topa]=x;
A[topa]++;
tree[x]=topa;
int zi=,mx=;
for(int i=;i<map[x].size();i++)
{
int dd=map[x][i];
if(num[dd]>mx)
{
mx=num[dd];
zi=dd;
}
}
if(zi!=) dfs2(zi); else return;
son[x]=zi;
for(int i=;i<map[x].size();i++)
{
int dd=map[x][i];
if(dd!=zi) dfs2(dd);
}
} void init(int o,int L,int R)
{
if(L==R) sumv[o]=A[L];
else
{
int M=(L+R)/;
init(o*,L,M);
init(o*+,M+,R);
sumv[o]=sumv[o*]+sumv[o*+];
}
} int y1,y2,p;
void update(int o,int L,int R)
{
if(L==R) sumv[o]=(sumv[o]+)%;
else
{
int M=(L+R)/;
if(p<=M) update(o*,L,M);
else update(o*+,M+,R);
sumv[o]=sumv[o*]+sumv[o*+];
}
} int ans;
void query(int o,int L,int R)
{
if(y1<=L && R<=y2) ans+=sumv[o];
else
{
int M=(L+R)/;
if(y1<=M) query(o*,L,M);
if(y2>M) query(o*+,M+,R);
}
} int main()
{
cin>>n;
for(int i=,x,y;i<=n-;i++)
{
cin>>x>>y;
map[x].push_back(y);
}
deep[]=;
dfs1();
dfs2(); init(,,n); cin>>k;
for(int i=,x;i<=k;i++)
{
char tp;
cin>>tp;
if(tp=='C')
{
cin>>x;
p=tree[x];
update(,,n);
}
else
{
cin>>x;
y1=tree[x];
y2=y1+num[x]-;
ans=;
query(,,n);
cout<<ans<<endl;
}
}
return ;
}