bzoj2243[SDOI2011]染色 树链剖分+线段树

时间:2023-03-08 16:57:11

2243: [SDOI2011]染色

Time Limit: 20 Sec  Memory Limit: 512 MB
Submit: 9012  Solved: 3375
[Submit][Status][Discuss]

Description

给定一棵有n个节点的无根树和m个操作,操作有2类:

1、将节点a到节点b路径上所有点都染成颜色c;

2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),

如“112221”由3段组成:“11”、“222”和“1”。

请你写一个程序依次完成这m个操作。

Input

第一行包含2个整数n和m,分别表示节点数和操作数;

第二行包含n个正整数表示n个节点的初始颜色

下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。

下面 行每行描述一个操作:

“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;

“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

Output

对于每个询问操作,输出一行答案。

Sample Input

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

Sample Output

3
1
2

HINT

数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。

Source

第一轮day1

树链剖分+线段树
操作还是常规的区间修改 但是要注意区间合并的时候,颜色段数量的合并
在树的路径上合并区间时,有点不好处理区间左右端点,于
是我们把u->v拆成u->lca lca->v
这样区间更新端点就是同向的,即每次查询可以把一端统一地作为更新端点
由于u->lca lca->v左右端点必定重合 所以 最后答案减一

sb的我update忘了pushup调了1h

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#define ls u<<1
#define rs ls|1
#define ll long long
#define N 100050
using namespace std;
int n,m,tot,cnt,a[N],fa[N][20],tp[N],siz[N],hd[N];
int dep[N],tid[N],son[N],v[N];
struct edge{int v,next;}e[N<<1];
struct node{int l,r,sum,lz;}t[N<<2];
void adde(int u,int v){
e[++tot].v=v;
e[tot].next=hd[u];
hd[u]=tot;
}
void dfs1(int u,int pre){
dep[u]=dep[pre]+1;fa[u][0]=pre;siz[u]=1;
for(int j=1;(1<<j)<dep[u];j++)
fa[u][j]=fa[fa[u][j-1]][j-1];
for(int i=hd[u];i;i=e[i].next){
int v=e[i].v;
if(v==pre)continue;
dfs1(v,u);siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int anc){
if(!u)return;
tid[u]=++cnt;v[cnt]=a[u];tp[u]=anc;
dfs2(son[u],anc);
for(int i=hd[u];i;i=e[i].next){
int v=e[i].v;
if(v==fa[u][0]||v==son[u])continue;
dfs2(v,v);
}
}
int lca(int x,int y){
if(x==y)return x;
if(dep[x]<dep[y])swap(x,y);
for(int i=18;~i;i--)
if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
if(x==y)return x;
for(int i=18;~i;i--){
if(fa[x][i]==fa[y][i])continue;
x=fa[x][i];y=fa[y][i];
}
return fa[x][0];
} void pushup(int u){
t[u].sum=t[ls].sum+t[rs].sum;
t[u].l=t[ls].l;t[u].r=t[rs].r;
if(t[ls].r==t[rs].l)t[u].sum--;
}
void pushdown(int u){
if(t[u].lz==-1)return;
t[ls].l=t[ls].r=t[ls].lz=t[u].lz;
t[rs].l=t[rs].r=t[rs].lz=t[u].lz;
t[ls].sum=t[rs].sum=1;
t[u].lz=-1;
}
void build(int u,int l,int r){
t[u].lz=-1;
if(l==r){
t[u].l=t[u].r=v[l];
t[u].sum=1;return;
}
int mid=l+r>>1;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(u);
}
node query(int u,int L,int R,int l,int r){
if(l<=L&&R<=r)return t[u];
pushdown(u);node ret;
int mid=L+R>>1,fg=0;
if(l<=mid)ret=query(ls,L,mid,l,r),fg=1;
if(r>mid){
node tmp=query(rs,mid+1,R,l,r);
if(!fg)ret=tmp;
else{
ret.sum+=tmp.sum;
if(ret.r==tmp.l)ret.sum--;
ret.r=tmp.r;
}
}
return ret;
}
void update(int u,int L,int R,int l,int r,int c){
if(l<=L&&R<=r){
t[u].lz=t[u].l=t[u].r=c;
t[u].sum=1;return;
}
pushdown(u);
int mid=L+R>>1;
if(l<=mid)update(ls,L,mid,l,r,c);
if(r>mid)update(rs,mid+1,R,l,r,c);
pushup(u);
}
node jump(int x,int y,int val,int op){
int fx=tp[x],fy=tp[y];node ret;
ret.l=ret.r=ret.sum=0;
while(fx!=fy){
if(op)update(1,1,cnt,tid[fx],tid[x],val);
else{
node tmp=query(1,1,cnt,tid[fx],tid[x]);
ret.sum+=tmp.sum;
if(tmp.r==ret.l)ret.sum--;
ret.l=tmp.l;
}
x=fa[fx][0];fx=tp[x];
}
if(dep[x]>dep[y])swap(x,y);
if(op)update(1,1,cnt,tid[x],tid[y],val);
else{
node tmp=query(1,1,cnt,tid[x],tid[y]);
ret.sum+=tmp.sum;
if(tmp.r==ret.l)ret.sum--;
ret.l=tmp.l;
}
return ret;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(int i=1;i<n;i++){
int a,b;
scanf("%d%d",&a,&b);
adde(a,b);adde(b,a);
}
dfs1(1,0);dfs2(1,1);
build(1,1,cnt);
int x,y,c;char s[2];
while(m--){
scanf("%s%d%d",s,&x,&y);
int anc=lca(x,y);
if(s[0]=='C'){
scanf("%d",&c);
jump(x,anc,c,1);
jump(y,anc,c,1);
}
else{
node t1,t2;
t1=jump(x,anc,0,0);
t2=jump(y,anc,0,0);
printf("%d\n",t1.sum+t2.sum-1);
}
}
return 0;
}