树链剖分(+线段树)(codevs4633)

时间:2021-06-16 12:43:20
{ Heavy-light decomposition + segment tree (codevs 4633).
  All node weights start at 0 (buildtree sets every sum to 0).
  Query "1 b c": add 1 to every node on the path b..c (both ends included).
  Query "2 b c": print the sum of the node weights on the path b..c.

  Restoration notes: the numeric literals of the original listing were lost,
  and line-joined "//" comments had swallowed several procedure headers and
  the main program; both are restored here.  Sums and lazy tags are int64
  because a path sum (up to q per node times n nodes) can exceed longint. }
const
  maxn=100000; { NOTE(review): assumed upper bound on n - confirm against the problem }
type
  node=^link;
  link=record
    des:longint;   { neighbour at the other end of this adjacency entry }
    next:node;     { next entry in the same adjacency list }
  end;
  seg=record
    z,y:longint;   { interval [z..y] of positions covered by this tree node }
    lc,rc:longint; { child indices in tr; 0 for a leaf }
    toadd:int64;   { per-position increment not yet folded into sum }
    sum:int64;     { sum of the interval, excluding this node's own toadd }
  end;
var
  n,tot,i,t1,t2,q,a,b,c:longint;
  p:node;
  son,siz,dep,fa,num,top:array[0..maxn] of longint;
  tr:array[0..2*maxn] of seg;
  nd:array[0..maxn] of node;

function max(a,b:longint):longint;
begin
  if a>b then exit(a) else exit(b);
end;

function min(a,b:longint):longint;
begin
  if a>b then exit(b) else exit(a);
end;

{ First DFS: depth, parent and subtree size of every node, plus its heavy
  child son[po] (the child with the largest subtree; 0 for a leaf, and
  siz[0]=0 so any real child beats it).  A node is unvisited while its dep
  is 0, so the root must get dep=1 before the call.
  NOTE(review): recursion depth equals tree height; a path-shaped tree may
  need a larger stack. }
procedure dfs1(po:longint);
var
  e:node;
begin
  siz[po]:=1;
  son[po]:=0;
  e:=nd[po];
  while e<>nil do
  begin
    if dep[e^.des]=0 then
    begin
      dep[e^.des]:=dep[po]+1;
      fa[e^.des]:=po;
      dfs1(e^.des);
      if siz[e^.des]>siz[son[po]] then
        son[po]:=e^.des;
      siz[po]:=siz[po]+siz[e^.des];
    end;
    e:=e^.next;
  end;
end;

{ Second DFS: link heavy edges into heavy chains.  num[po] is the position
  of po in the segment tree (the heavy child is numbered right after its
  parent, so every chain is a contiguous range); top[po] is the head of
  the chain containing po. }
procedure dfs2(po,tp:longint);
var
  e:node;
begin
  inc(tot);
  num[po]:=tot;
  top[po]:=tp;
  if son[po]<>0 then
    dfs2(son[po],tp);
  e:=nd[po];
  while e<>nil do
  begin
    if (e^.des<>son[po]) and (e^.des<>fa[po]) then
      dfs2(e^.des,e^.des); { every light child starts a new chain }
    e:=e^.next;
  end;
end;

{ Build a segment tree over positions l..r; node indices are allocated
  from tot in pre-order. }
procedure buildtree(l,r:longint);
var
  t:longint;
begin
  inc(tot);
  tr[tot].sum:=0;
  tr[tot].toadd:=0;
  tr[tot].z:=l;
  tr[tot].y:=r;
  t:=tot;
  if l<>r then
  begin
    tr[t].lc:=tot+1;
    buildtree(l,(l+r) div 2);
    tr[t].rc:=tot+1;
    buildtree((l+r) div 2+1,r);
  end;
end;

{ Fold node po's pending tag into its own sum and hand it on to its
  children (a leaf has none, so nothing is written to tr[0]). }
procedure pushdown(po:longint);
begin
  if tr[po].toadd<>0 then
  begin
    tr[po].sum:=tr[po].sum+(tr[po].y-tr[po].z+1)*tr[po].toadd;
    if tr[po].lc<>0 then
    begin
      tr[tr[po].lc].toadd:=tr[tr[po].lc].toadd+tr[po].toadd;
      tr[tr[po].rc].toadd:=tr[tr[po].rc].toadd+tr[po].toadd;
    end;
    tr[po].toadd:=0;
  end;
end;

{ Add k to every position in [l..r], which must lie inside node po's
  interval. }
procedure add(po,l,r,k:longint);
var
  mid:longint;
begin
  pushdown(po);
  tr[po].sum:=tr[po].sum+int64(r-l+1)*k;
  if (l=tr[po].z) and (r=tr[po].y) then
  begin
    { whole interval covered: defer the update to the children }
    if tr[po].lc<>0 then
    begin
      tr[tr[po].lc].toadd:=tr[tr[po].lc].toadd+k;
      tr[tr[po].rc].toadd:=tr[tr[po].rc].toadd+k;
    end;
    exit;
  end;
  mid:=(tr[po].z+tr[po].y) div 2;
  if mid>=l then add(tr[po].lc,l,min(mid,r),k);
  if r>mid then add(tr[po].rc,max(mid+1,l),r,k);
end;

{ Sum of positions [l..r], which must lie inside node po's interval. }
function ans(po,l,r:longint):int64;
var
  mid:longint;
  res:int64;
begin
  pushdown(po);
  if (l=tr[po].z) and (r=tr[po].y) then
    exit(tr[po].sum);
  mid:=(tr[po].z+tr[po].y) div 2;
  res:=0;
  if mid>=l then res:=res+ans(tr[po].lc,l,min(mid,r));
  if r>mid then res:=res+ans(tr[po].rc,max(mid+1,l),r);
  ans:=res;
end;

{ Add 1 to every node on the tree path b..c: repeatedly handle the chain
  whose head is deeper, then the final segment on the shared chain. }
procedure plus(b,c:longint);
begin
  while top[b]<>top[c] do
  begin
    if dep[top[b]]<dep[top[c]] then
    begin
      add(1,num[top[c]],num[c],1);
      c:=fa[top[c]];
    end
    else
    begin
      add(1,num[top[b]],num[b],1);
      b:=fa[top[b]];
    end;
  end;
  if num[b]<num[c] then add(1,num[b],num[c],1) else add(1,num[c],num[b],1);
end;

{ Sum of node weights on the tree path b..c, using the same chain walk
  as plus. }
function query(b,c:longint):int64;
var
  res:int64;
begin
  res:=0;
  while top[b]<>top[c] do
  begin
    if dep[top[b]]<dep[top[c]] then
    begin
      res:=res+ans(1,num[top[c]],num[c]);
      c:=fa[top[c]];
    end
    else
    begin
      res:=res+ans(1,num[top[b]],num[b]);
      b:=fa[top[b]];
    end;
  end;
  if num[b]<num[c] then res:=res+ans(1,num[b],num[c])
  else res:=res+ans(1,num[c],num[b]);
  query:=res;
end;

begin
  read(n);
  { n-1 undirected edges, stored in both adjacency lists }
  for i:=1 to n-1 do
  begin
    read(t1,t2);
    new(p);
    p^.des:=t2; p^.next:=nd[t1]; nd[t1]:=p;
    new(p);
    p^.des:=t1; p^.next:=nd[t2]; nd[t2]:=p;
  end;
  dep[1]:=1;  { node 1 is the root; non-zero depth marks it as visited }
  dfs1(1);
  tot:=0;
  dfs2(1,1);
  tot:=0;     { tot is reused as the segment-tree node allocator }
  buildtree(1,n);
  read(q);
  for i:=1 to q do
  begin
    read(a,b,c);
    if a=1 then plus(b,c);
    if a=2 then writeln(query(b,c));
  end;
end.

————————————————————————————————————————————————————————————————

C++ (BZOJ 1036)

// Heavy-light decomposition + segment tree (BZOJ 1036).
// Reads a tree with node weights, then answers queries:
//   QMAX u v   -> print the maximum node weight on the path u..v
//   QSUM u v   -> print the sum of node weights on the path u..v
//   CHANGE u t -> set the weight of node u to t
// Restoration/fix notes: the numeric literals lost from the original listing
// are restored; the globals `next` and `size` are renamed to `nxt` and `sz`
// because with `using namespace std` they are ambiguous with std::next and
// std::size and the program does not compile; scanf("%s") now receives the
// buffer itself instead of a pointer to the array.
#include <cstdio>
#include <iostream>
#include <algorithm>
#define LL long long
using namespace std;

// NOTE(review): assumed bound on n for BZOJ 1036 - confirm against the problem.
const int MAXN = 30000 + 10;

// Adjacency lists: nd[x] is the first edge of x, nxt/des chain the entries.
int nxt[2 * MAXN], des[2 * MAXN], nd[MAXN], bt[MAXN], son[MAXN], maxi[MAXN];
int fa[MAXN], dep[MAXN], sz[MAXN], id[MAXN], top[MAXN], a[MAXN], revid[MAXN];
int cnt, n, q;

// Segment tree node over positions [l, r]; lc/rc are child indices in tr.
struct node {
    int l, r, lc, rc, maxi, sum;
} tr[2 * MAXN];

// Store the undirected edge x-y as two directed adjacency entries.
void addedge(int x, int y) {
    nxt[++cnt] = nd[x]; des[cnt] = y; nd[x] = cnt;
    nxt[++cnt] = nd[y]; des[cnt] = x; nd[y] = cnt;
}

// First DFS: parent, depth, subtree size and heavy child (-1 for a leaf).
// bt[] marks visited nodes.
void dfs1(int po) {
    bt[po] = 1;
    son[po] = -1; maxi[po] = -1;
    sz[po] = 1;
    for (int p = nd[po]; p != -1; p = nxt[p]) {
        int v = des[p];
        if (bt[v]) continue;
        fa[v] = po; dep[v] = dep[po] + 1;
        dfs1(v);
        sz[po] += sz[v];
        if (sz[v] > maxi[po]) {
            maxi[po] = sz[v];
            son[po] = v;
        }
    }
}

// Second DFS: assign positions so each heavy chain is contiguous
// (id[] = position, top[] = head of the node's chain).
void dfs2(int po, int tp) {
    id[po] = ++cnt; top[po] = tp;
    if (son[po] == -1) return;
    dfs2(son[po], tp);
    for (int p = nd[po]; p != -1; p = nxt[p])
        if (des[p] != fa[po] && des[p] != son[po]) dfs2(des[p], des[p]);
}

// Recompute a node's sum and max from its two children.
void update(int po) {
    tr[po].sum = tr[tr[po].lc].sum + tr[tr[po].rc].sum;
    tr[po].maxi = max(tr[tr[po].lc].maxi, tr[tr[po].rc].maxi);
}

// Build the tree over positions [l, r]; position p holds a[revid[p]].
void build(int l, int r) {
    tr[++cnt].l = l; tr[cnt].r = r;
    if (l == r) {
        tr[cnt].sum = tr[cnt].maxi = a[revid[l]];
        return;
    }
    int t = cnt, mid = (l + r) >> 1;
    tr[t].lc = cnt + 1;
    build(l, mid);
    tr[t].rc = cnt + 1;
    build(mid + 1, r);
    update(t);
}

// Set position targ to val and refresh the ancestors.
void edi(int po, int targ, int val) {
    if (tr[po].l == tr[po].r) {
        tr[po].sum = tr[po].maxi = val;
        return;
    }
    int mid = (tr[po].l + tr[po].r) >> 1;
    if (targ <= mid) edi(tr[po].lc, targ, val);
    else edi(tr[po].rc, targ, val);
    update(po);
}

// Maximum over positions [l, r] inside node po's interval.
int getmax(int po, int l, int r) {
    if (l == tr[po].l && r == tr[po].r) return tr[po].maxi;
    int mid = (tr[po].l + tr[po].r) >> 1;
    int ret = -1000000000;
    if (l <= mid) ret = max(ret, getmax(tr[po].lc, l, min(mid, r)));
    if (r > mid) ret = max(ret, getmax(tr[po].rc, max(mid + 1, l), r));
    return ret;
}

// Print the maximum node weight on the path x..y: always lift the endpoint
// whose chain head is deeper, then finish on the common chain.
void QMAX(int x, int y) {
    int ans = -1000000000;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        ans = max(ans, getmax(1, id[top[x]], id[x]));
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    ans = max(ans, getmax(1, id[y], id[x]));
    printf("%d\n", ans);
}

// Sum over positions [l, r] inside node po's interval.
int getsum(int po, int l, int r) {
    if (l == tr[po].l && r == tr[po].r) return tr[po].sum;
    int mid = (tr[po].l + tr[po].r) >> 1;
    int ret = 0;
    if (l <= mid) ret += getsum(tr[po].lc, l, min(mid, r));
    if (r > mid) ret += getsum(tr[po].rc, max(mid + 1, l), r);
    return ret;
}

// Print the sum of node weights on the path x..y (same chain walk as QMAX).
void QSUM(int x, int y) {
    int ans = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        ans += getsum(1, id[top[x]], id[x]);
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    ans += getsum(1, id[y], id[x]);
    printf("%d\n", ans);
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) nd[i] = -1;
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        addedge(x, y);
    }
    dep[1] = 1;            // node 1 is the root
    dfs1(1);
    cnt = 0;               // cnt is reused as the position counter
    dfs2(1, 1);
    for (int i = 1; i <= n; i++) revid[id[i]] = i;
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);  // a[] is indexed by node
    cnt = 0;               // ...and then as the segment-tree allocator
    build(1, n);
    scanf("%d", &q);
    char st[16];
    for (int i = 1; i <= q; i++) {
        scanf("%s", st);
        int x, y;
        scanf("%d%d", &x, &y);
        // The second letter distinguishes QMAX / QSUM / CHANGE.
        if (st[1] == 'M') QMAX(x, y);
        else if (st[1] == 'S') QSUM(x, y);
        else if (st[1] == 'H') {
            a[x] = y;
            edi(1, id[x], y);
        }
    }
    return 0;
}

——————————————————————————————————

树链剖分可对每条链单独建立线段树以减小常数

// Heavy-light decomposition with a separate segment tree per heavy chain
// (smaller constant factor). Each edge weight is stored on the edge's deeper
// endpoint, and the trees hold products of weights, saturated to -1 once a
// product exceeds 1e18.
//   1 x y v : print v / (product of edge weights on the path x..y), or 0 when
//             that product exceeds 1e18
//   2 i v   : set the weight of the i-th input edge to v
// Restoration/fix notes: the lost numeric literals are restored; the globals
// `next` and `size` are renamed (they are ambiguous with std::next/std::size
// under `using namespace std`); the edge weight was read with %lld into an
// int (undefined behaviour) and is now a long long; the broken "swap" in the
// edge update, which overwrote fr[]/to[], is replaced by picking the child
// endpoint directly; the overflow test uses integer arithmetic.
#include <cstdio>
#include <iostream>
#include <algorithm>
#define LL long long
using namespace std;

// NOTE(review): assumed bound on n - confirm against the problem statement.
const int MAXN = 200000 + 10;
// Products above this are stored as -1.
const LL LIMIT = 1000000000000000000LL;

int nxt[2 * MAXN], des[2 * MAXN], nd[MAXN], cnt, sz[MAXN], vis[MAXN], fa[MAXN], dep[MAXN], son[MAXN];
int id[MAXN], rev[MAXN], top[MAXN], n, q, fr[MAXN], to[MAXN], root[MAXN], maxid[MAXN];
LL len[2 * MAXN];   // weight of each adjacency entry
LL num[MAXN];       // num[v] = weight of the edge from v to its parent

// Segment tree node over positions [l, r]; num is the (saturated) product.
struct treenode {
    int l, r, lc, rc;
    LL num;
} tr[2 * MAXN];

// Store the undirected edge x-y with weight w as two adjacency entries.
void addedge(int x, int y, LL w) {
    nxt[++cnt] = nd[x]; des[cnt] = y; len[cnt] = w; nd[x] = cnt;
    nxt[++cnt] = nd[y]; des[cnt] = x; len[cnt] = w; nd[y] = cnt;
}

// First DFS: parent, depth, subtree size, heavy child (0 for a leaf), and
// the parent-edge weight of each child.
// NOTE(review): recursion depth equals tree height; deep trees need a big stack.
void dfs1(int po) {
    sz[po] = 1; vis[po] = 1;
    int best = -1;
    for (int p = nd[po]; p != -1; p = nxt[p]) {
        int v = des[p];
        if (vis[v]) continue;
        num[v] = len[p]; fa[v] = po;
        dep[v] = dep[po] + 1;
        dfs1(v);
        if (sz[v] > best) {
            best = sz[v];
            son[po] = v;
        }
        sz[po] += sz[v];
    }
}

// Second DFS: positions id[] with every heavy chain contiguous; rev[] is the
// inverse mapping and top[] the head of each node's chain.
void dfs2(int po, int tp) {
    id[po] = ++cnt; rev[cnt] = po; top[po] = tp;
    if (son[po]) dfs2(son[po], tp);
    for (int p = nd[po]; p != -1; p = nxt[p])
        if (des[p] != fa[po] && des[p] != son[po])
            dfs2(des[p], des[p]);
}

// a = b * c, where -1 means "greater than LIMIT" and is absorbing.
void update(LL &a, LL b, LL c) {
    if (b == -1 || c == -1) {
        a = -1;
        return;
    }
    if (b != 0 && c > LIMIT / b) {
        a = -1;
        return;
    }
    a = b * c;
}

// Build the tree over positions [l, r] (one chain); leaves take num[].
void build(int l, int r) {
    tr[++cnt].l = l; tr[cnt].r = r;
    if (l == r) {
        tr[cnt].num = num[rev[l]];
        return;
    }
    int mid = (l + r) >> 1, t = cnt;
    tr[t].lc = cnt + 1;
    build(l, mid);
    tr[t].rc = cnt + 1;
    build(mid + 1, r);
    update(tr[t].num, tr[tr[t].lc].num, tr[tr[t].rc].num);
}

// Set position tar to val and recompute products up to node po.
void edi(int po, int tar, LL val) {
    if (tr[po].l == tr[po].r) {
        tr[po].num = val;
        return;
    }
    int mid = (tr[po].l + tr[po].r) >> 1;
    if (tar <= mid) edi(tr[po].lc, tar, val);
    else edi(tr[po].rc, tar, val);
    update(tr[po].num, tr[tr[po].lc].num, tr[tr[po].rc].num);
}

// Saturated product over positions [l, r] inside node po's interval.
LL getnum(int po, int l, int r) {
    if (tr[po].l == l && tr[po].r == r) return tr[po].num;
    int mid = (tr[po].l + tr[po].r) >> 1;
    LL ret = 1;
    if (l <= mid) update(ret, ret, getnum(tr[po].lc, l, min(mid, r)));
    if (r > mid) update(ret, ret, getnum(tr[po].rc, max(mid + 1, l), r));
    return ret;
}

// Saturated product of edge weights on the path x..y. Each chain jump
// includes the head's parent edge; on the final shared chain the shallower
// endpoint's own parent edge is excluded by starting at its heavy child.
LL query(int x, int y) {
    LL ret = 1;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        LL t = getnum(root[top[x]], id[top[x]], id[x]);
        update(ret, ret, t);
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    if (x == y) return ret;
    LL t = getnum(root[top[x]], id[son[y]], id[x]);
    update(ret, ret, t);
    return ret;
}

int main() {
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; i++) nd[i] = -1;
    for (int i = 1; i < n; i++) {
        LL w;
        scanf("%d%d%lld", &fr[i], &to[i], &w);
        addedge(fr[i], to[i], w);
    }
    dep[1] = 1;            // node 1 is the root
    dfs1(1);
    num[1] = 1;            // the root has no parent edge: use the neutral factor
    cnt = 0;               // cnt is reused as the position counter
    dfs2(1, 1);
    cnt = 0;               // ...and then as the segment-tree allocator
    for (int i = 1; i <= n; i++) maxid[top[i]] = max(maxid[top[i]], id[i]);
    // One tree per chain, covering positions id[head]..maxid[head].
    for (int i = 1; i <= n; i++)
        if (i == top[i]) {
            root[i] = cnt + 1;
            build(id[i], maxid[i]);
        }
    for (int i = 1; i <= q; i++) {
        int typ;
        scanf("%d", &typ);
        if (typ == 1) {
            int x, y;
            LL v;
            scanf("%d%d%lld", &x, &y, &v);
            LL t = query(x, y);
            if (t == -1) printf("0\n");
            else printf("%lld\n", v / t);
        } else if (typ == 2) {
            int li;
            LL v;
            scanf("%d%lld", &li, &v);
            // The edge's weight lives on its deeper endpoint (the child).
            int child = (fa[fr[li]] == to[li]) ? fr[li] : to[li];
            edi(root[top[child]], id[child], v);
        }
    }
    return 0;
}