bzoj 2243 [SDOI2011]染色(树链剖分+线段树合并)

时间:2023-03-08 16:57:06
bzoj 2243 [SDOI2011]染色(树链剖分+线段树合并)

【bzoj2243】[SDOI2011]染色

2017年10月20日

Description

给定一棵有n个节点的无根树和m个操作,操作有2类:

1、将节点a到节点b路径上所有点都染成颜色c;

2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。

请你写一个程序依次完成这m个操作。

Input

第一行包含2个整数n和m,分别表示节点数和操作数;

第二行包含n个正整数表示n个节点的初始颜色

下面 行每行包含两个整数x和y,表示xy之间有一条无向边。

下面 行每行描述一个操作:

“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;

“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

Output

对于每个询问操作,输出一行答案。

Sample Input

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

Sample Output

3
1
2

HINT

数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。、

题解:

题目很好理解,它不是边染色,而是点染色,这个性质是比较好的,边染色还需要裂点。

看的题目就可以想到这是树链剖分模板题吧,套个裸的线段树合并,其实没有什么合并的

东西,发现一段线段的不同颜色,那么就需要记录左端点和右端点颜色,如果左区间右端

点和右区间左端颜色一样,那么总颜色-1,这个比较好理解的吧,然后记录一个该区间总

颜色数,就可以统计了。

程序比较结构化
两个dfs预处理,lca,线段树,询问处理+更新处理,就ok了,代码比较清晰。

 #include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#define N 100007
using namespace std; int n,m,sz=;
int cnt,head[N],next[N*],rea[N*];
int a[N];
int fa[N][],size[N],pos[N],bel[N],deep[N];
char ch[];
struct Node
{
int lc,rc,tag,num;
}tr[N*]; void add(int u,int v){next[++cnt]=head[u],head[u]=cnt,rea[cnt]=v;}
void dfs_init(int u)
{
size[u]=;
for (int i=;(<<i)<=deep[u];i++)
fa[u][i]=fa[fa[u][i-]][i-];
for (int i=head[u];i!=-;i=next[i])
{
int v=rea[i];
if (v==fa[u][]) continue;
deep[v]=deep[u]+;
fa[v][]=u;
dfs_init(v);
size[u]+=size[v];
}
}
void dfs_make(int u,int chain)
{
int k=;
pos[u]=++sz,bel[u]=chain;
for (int i=head[u];i!=-;i=next[i])
{
int v=rea[i];
if (deep[v]>deep[u]&&size[v]>size[k]) k=v;
}
if (k==) return;
dfs_make(k,chain);
for (int i=head[u];i!=-;i=next[i])
{
int v=rea[i];
if (deep[v]>deep[u]&&v!=k) dfs_make(v,v);
}
}
int lca(int a,int b)
{
if (deep[a]<deep[b]) swap(a,b);
int i;
for (i=;(<<i)<=deep[a];i++);
i--;
for (int j=i;j>=;j--)
if (deep[a]-(<<j)>=deep[b]) a=fa[a][j];
if (a==b) return a;
for (int j=i;j>=;j--)
if (fa[a][j]!=fa[b][j]) a=fa[a][j],b=fa[b][j];
return fa[a][];
}
void updata_down(int l,int r,int p)
{
int tag=tr[p].tag;tr[p].tag=-;
if (tag==-||l==r) return;
tr[p<<].num=tr[p<<|].num=;
tr[p<<].tag=tr[p<<|].tag=tag;
tr[p<<].lc=tr[p<<].rc=tag;
tr[p<<|].lc=tr[p<<|].rc=tag;
}
void updata_up(int l,int r,int p)
{
tr[p].lc=tr[p<<].lc,tr[p].rc=tr[p<<|].rc;
tr[p].num=tr[p<<].num+tr[p<<|].num;
if (tr[p<<].rc==tr[p<<|].lc) tr[p].num--;
}
void change(int l,int r,int p,int x,int y,int z)
{
updata_down(l,r,p);
if (l==x&&y==r)
{tr[p].num=,tr[p].lc=tr[p].rc=tr[p].tag=z;return;}
int mid=(l+r)>>;
if (y<=mid) change(l,mid,p<<,x,y,z);
else if (x>mid) change(mid+,r,p<<|,x,y,z);
else change(l,mid,p<<,x,mid,z),change(mid+,r,p<<|,mid+,y,z);
updata_up(l,r,p);
}
int query(int l,int r,int p,int x,int y)
{
updata_down(l,r,p);
if (l==x&&y==r) return tr[p].num;
int mid=(l+r)>>,res;
if (y<=mid) res=query(l,mid,p<<,x,y);
else if (x>mid) res=query(mid+,r,p<<|,x,y);
else
{
res=query(l,mid,p<<,x,mid)+query(mid+,r,p<<|,mid+,y);
if (tr[p<<].rc==tr[p<<|].lc) res--;
}
return res;
}
int find(int l,int r,int p,int x)
{
updata_down(l,r,p);
if (l==r) return tr[p].lc;
int mid=(l+r)>>;
if (x<=mid) return find(l,mid,p<<,x);
else return find(mid+,r,p<<|,x);
}
int solvequery(int x,int fq)
{
int res=;
while(bel[x]!=bel[fq])
{
res+=query(,n,,pos[bel[x]],pos[x]);
if (find(,n,,pos[bel[x]])==find(,n,,pos[fa[bel[x]][]])) res--;
x=fa[bel[x]][];
}
res+=query(,n,,pos[fq],pos[x]);
return res;
}
void solvechange(int x,int fq,int z)
{
while(bel[x]!=bel[fq])
{
change(,n,,pos[bel[x]],pos[x],z);
x=fa[bel[x]][];
}
change(,n,,pos[fq],pos[x],z);
}
int main()
{
memset(head,-,sizeof(head));tr[].tag=-;
scanf("%d%d",&n,&m);
for(int i=;i<=n;i++)
scanf("%d",&a[i]);
int x,y,z;
for (int i=;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
dfs_init();
dfs_make(,);
for (int i=;i<=n;i++)
change(,n,,pos[i],pos[i],a[i]);
//==============================================================
for (int i=;i<=m;i++)
{
scanf("%s",ch);
if (ch[]=='Q')
{
scanf("%d%d",&x,&y);
int par=lca(x,y);
printf("%d\n",solvequery(x,par)+solvequery(y,par)-);
}
else
{
scanf("%d%d%d",&x,&y,&z);
int par=lca(x,y);
solvechange(x,par,z),solvechange(y,par,z);
}
}
}