bzoj 1858: [Scoi2010]序列操作

时间:2022-10-01 16:40:58

1858: [Scoi2010]序列操作

Time Limit: 10 Sec  Memory Limit: 64 MB

线段树,对于每个区间需要分别维护左右和中间的1和0连续个数,并在op=4时特殊处理一下。

Description

lxhgww最近收到了一个01序列,序列里面包含了n个数,这些数要么是0,要么是1,现在对于这个序列有五种变换操作和询问操作:
   0 a b 把[a, b]区间内的所有数全变成0;
  1 a b 把[a, b]区间内的所有数全变成1;
  2 a b 把[a,b]区间内的所有数全部取反,也就是说把所有的0变成1,把所有的1变成0;
  3 a b 询问[a, b]区间内总共有多少个1;
  4 a b 询问[a, b]区间内最多有多少个连续的1。
对于每一种询问操作,lxhgww都需要给出回答,聪明的程序员们,你们能帮助他吗?

Input

输入数据第一行包括2个数,n和m,分别表示序列的长度和操作数目 第二行包括n个数,表示序列的初始状态 接下来m行,每行3个数,op, a, b,(0<=op<=4,0<=a<=b<n)表示对于区间[a, b]执行标号为op的操作。

Output

对于每一个询问操作,输出一行,包括1个数,表示其对应的答案

Sample Input

10 10
0 0 0 1 1 0 1 0 1 1
1 0 2
3 0 5
2 2 2
4 0 4
0 3 6
2 3 7
4 2 8
1 0 5
0 5 6
3 3 9

Sample Output

5
2
6
5

HINT

对于30%的数据,1<=n, m<=1000
对于100%的数据,1<=n, m<=100000

Source

#include<cstdio>
#define M 100010
inline int max(int a,int b){return a>b?a:b;}
struct tree{int l,r,s,lazy,sum,ml1,mr1,mm1,ml0,mr0,mm0;}tr[M*];
int a[M],n,m,x,y,v,op;
inline int read()
{
int tmp=;
char ch=getchar();
while(ch<''||ch>'') ch=getchar();
while(ch>=''&&ch<=''){tmp=tmp*+ch-'';ch=getchar();}
return tmp;
}
inline void pu(int p)
{
int p1=p<<,p2=p<<|;
tr[p].sum=tr[p1].sum+tr[p2].sum;
tr[p].mm1=tr[p1].mr1+tr[p2].ml1;
tr[p].mm0=tr[p1].mr0+tr[p2].ml0;
if(tr[p1].mr1==tr[p1].s) tr[p].ml1=tr[p1].mr1+tr[p2].ml1;
else tr[p].ml1=tr[p1].ml1;
if(tr[p1].mr0==tr[p1].s) tr[p].ml0=tr[p1].mr0+tr[p2].ml0;
else tr[p].ml0=tr[p1].ml0;
if(tr[p2].mr1==tr[p2].s) tr[p].mr1=tr[p2].mr1+tr[p1].mr1;
else tr[p].mr1=tr[p2].mr1;
if(tr[p2].mr0==tr[p2].s) tr[p].mr0=tr[p2].mr0+tr[p1].mr0;
else tr[p].mr0=tr[p2].mr0;
tr[p].mm1=max(tr[p].mm1,max(tr[p1].mm1,tr[p2].mm1));
tr[p].mm0=max(tr[p].mm0,max(tr[p1].mm0,tr[p2].mm0));
}
inline void change(int p)
{
int t1=tr[p].ml0,t2=tr[p].mr0,t3=tr[p].mm0;
tr[p].ml0=tr[p].ml1;
tr[p].ml1=t1;
tr[p].mr0=tr[p].mr1;
tr[p].mr1=t2;
tr[p].mm0=tr[p].mm1;
tr[p].mm1=t3;
}
inline void make(int l,int r,int p)
{
tr[p].l=l,tr[p].r=r,tr[p].s=tr[p].r-tr[p].l+,tr[p].lazy=-;
if(l==r)
{
tr[p].sum=a[l];
tr[p].ml1=tr[p].mm1=tr[p].mr1=a[l];
tr[p].ml0=tr[p].mm0=tr[p].mr0=a[l]^;
return;
}
int mid=(l+r)>>;
make(l,mid,p<<);
make(mid+,r,p<<|);
pu(p);
}
inline void pd(int p)
{
int lz=tr[p].lazy,p1=p<<,p2=p<<|;
if(lz==)
{
tr[p1].lazy=tr[p2].lazy=;
tr[p1].sum=;
tr[p2].sum=;
tr[p1].ml1=tr[p1].mr1=tr[p1].mm1=;
tr[p2].ml1=tr[p2].mr1=tr[p2].mm1=;
tr[p1].ml0=tr[p1].mr0=tr[p1].mm0=tr[p1].s;
tr[p2].ml0=tr[p2].mr0=tr[p2].mm0=tr[p2].s;
}
else if(lz==)
{
tr[p1].lazy=tr[p2].lazy=;
tr[p1].sum=tr[p1].s;
tr[p2].sum=tr[p2].s;
tr[p1].ml1=tr[p1].mr1=tr[p1].mm1=tr[p1].s;
tr[p2].ml1=tr[p2].mr1=tr[p2].mm1=tr[p2].s;
tr[p1].ml0=tr[p1].mr0=tr[p1].mm0=;
tr[p2].ml0=tr[p2].mr0=tr[p2].mm0=;
}
else if(lz==)
{
if(tr[p1].lazy==) tr[p1].lazy=;
else if(tr[p1].lazy==) tr[p1].lazy=;
else if(tr[p1].lazy==) tr[p1].lazy=-;
else tr[p1].lazy=;
if(tr[p2].lazy==) tr[p2].lazy=;
else if(tr[p2].lazy==) tr[p2].lazy=;
else if(tr[p2].lazy==) tr[p2].lazy=-;
else tr[p2].lazy=;
tr[p1].sum=tr[p1].s-tr[p1].sum;
tr[p2].sum=tr[p2].s-tr[p2].sum;
change(p1);change(p2);
}
tr[p].lazy=-;
}
inline int find1(int l,int r,int p)
{
pd(p);
if(tr[p].l==l&&tr[p].r==r) return tr[p].sum;
int mid=(tr[p].l+tr[p].r)>>;
if(mid>=r) return find1(l,r,p<<);
else if(mid<l) return find1(l,r,p<<|);
else return find1(l,mid,p<<)+find1(mid+,r,p<<|);
}
inline int findl(int l,int r,int p)
{
pd(p);
if(tr[p].ml1+l>r) return r-l+;
else return tr[p].ml1;
}
inline int findr(int l,int r,int p)
{
pd(p);
if(l+tr[p].mr1>r) return r-l+;
else return tr[p].mr1;
}
inline int find2(int l,int r,int p)
{
pd(p);
if(tr[p].l==l&&tr[p].r==r) return max(tr[p].ml1,max(tr[p].mm1,tr[p].mr1));
int mid=(tr[p].l+tr[p].r)>>;
if(mid>=r) return find2(l,r,p<<);
else if(mid<l) return find2(l,r,p<<|);
else return max(max(find2(l,mid,p<<),find2(mid+,r,p<<|)),findl(mid+,r,p<<|)+findr(l,mid,p<<));
}
inline void xg(int l,int r,int c,int p)
{
pd(p);
if(tr[p].l==l&&tr[p].r==r)
{
tr[p].lazy=c;
if(c==)
{
tr[p].ml1=tr[p].mr1=tr[p].mm1=;
tr[p].ml0=tr[p].mr0=tr[p].mm0=tr[p].s;
tr[p].sum=;
}
else if(c==)
{
tr[p].ml0=tr[p].mr0=tr[p].mm0=;
tr[p].ml1=tr[p].mr1=tr[p].mm1=tr[p].s;
tr[p].sum=tr[p].s;
}
else
{
tr[p].sum=tr[p].s-tr[p].sum;
change(p);
}
return;
}
int mid=(tr[p].l+tr[p].r)>>;
if(mid>=r) xg(l,r,c,p<<);
else if(mid<l) xg(l,r,c,p<<|);
else
{
xg(l,mid,c,p<<);
xg(mid+,r,c,p<<|);
}
pu(p);
}
int main()
{
n=read(),m=read();
for(int i=;i<=n;i++) a[i]=read();
make(,n,);
for(int i=;i<m;i++)
{
op=read();x=read();y=read();
if(op<) xg(x+,y+,op,);
else if(op==) printf("%d\n",find1(x+,y+,));
else printf("%d\n",find2(x+,y+,));
}
return ;
}