[BZOJ4695][picks loves segment tree III]

时间:2022-06-07 12:37:13

【题面】

Description

给定一个长度为 N序列,编号从1 到 N。要求支持下面几种操作:
1.给一个区间[L,R] 加上一个数x 
2.把一个区间[L,R] 里小于x 的数变成x 
3.把一个区间[L,R] 里大于x 的数变成x 
4.求区间[L,R] 的和
5.求区间[L,R] 的最大值
6.求区间[L,R] 的最小值

Input

第一行一个整数 N表示序列长度。
第二行N 个整数Ai 表示初始序列。
第三行一个整数M 表示操作个数。
接下来M 行,每行三或四个整数,第一个整数Tp 表示操作类型,接下来L,R,X 或L,R 表述操作数。
1<=tp<=6,N,M<=5*10^5,|Ai|<=10^8
Tp=1时,|x|<=1000
Tp=2或3时,|x|<=10^8

Output

对于每个4,5,6类型的操作输出一行一个整数表示答案。

Sample Input

2
1 2
2
2 1 2 2
4 1 2

Sample Output

4

以下内容所有权利属于c_sunshine和jiry_2

【解题思路】

对线段树上的每一个区间维护区间最大值mx,这个区间中最大值出 现的次数t,区间次大值se,当然还要维护区间和sum 
现在考虑打上区间取min标记x: 
• 如果mx<=x,那么对sum就没有修改。 
• 如果se<x<mx,那么sum=sum-(mx-x)×t。 
• 如果x<=se<mx,那么… 
如果遇到这种情况,我们分别DFS这个节点的两个孩子,如果当前 DFS的过程中遇到了前两种情况,就直接修改打上标记然后退出,否则就继续DFS。
[BZOJ4695][picks loves segment tree III]
我们来试着写一写这玩意的程序,发现跑的飞快。
实际上这个做法的复杂度是 O(nlogn) 的。 

【复杂度证明】

我们把最大值和它的父节点不同的节点称为关键点,令势函数Φ为线 段树中的关键点个数,显然有0≤Φ≤n 。
[BZOJ4695][picks loves segment tree III]
考虑DFS时的终止节点v,设它的父亲为f,那么有mxv=mxf或者 mxv<mxf,其中第二类点在DFS前是关键点,在DFS后不是关键点。
设第二类点数为A,那么访问的总点数一定是O(Alogn)的。此时我们 花费的时间为O(Alogn),而Φ减少了A。 
考虑所有修改操作,每一次修改的时候影响到了线段树上O(logn)个 节点,最坏的情况下每一个受影响的节点都由非关键点变成了关键点, 那么每一次修改操作都至多使势函数增加O(logn) 
因此势函数的总变化量是O(mlogn)的,所以可以得到每一次DFS的 均摊复杂度是O(log2n)。
由此可以得到这个算法的时间复杂度是O(nlogn+mlog2n)。
虽然目前的复杂度已经比分块优秀许多了,但是依然有些难以接受。
经过实现之后,发现这个算法处理500000的数据在UOJ上只需要0.6s, 实际的运行效率更像O(nlogn)。
我们可以换种方式来考虑:把线段树上每一个节点的最大值看成是区 间取min标记,次大值看成是子树中标记的最大值。既然把最大值看 成是标记,那么一些无用的标记就可以被删去
[BZOJ4695][picks loves segment tree III]
这些标记满足每一个标记的值一定比它子树中的所有标记的值大。因此每一个位置实际的值等价与它到根路径上碰到的第一个标记的值,而上述算法中的DFS过程,相当于是对子树中比当前标记大的标记进 行了回收。
考虑区间加减对区间取min标记的影响,其实就和普通的线段树一样, 我们对所有访问到的节点都进行了一次标记下传。
因为回收标记的时间复杂度不会超过打标记和标记下传的时间复杂度之和,所以就有标记回收(即DFS)的时间复杂度是O(mlogn)的。
因此,上述算法的时间复杂度为 O(mlogn)

【呆马(巨丑)】

#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<iostream>
#define ll long long
using namespace std;
const int N=5*(1e5+1),inf=1e9;
struct st{int max,min,se1,se2,num1,num2,tag,tmx,tmn; ll sum;} T[N<<2];
int n,m,i,x,y,z,tp,a[N];
ll ans;
void modify(int,int,int,int,int,int,int);
void pushdown(int,int,int);
void go(int &x,int y){x+=x==inf||x==-inf?0:y;}
inline int max(int x,int y){int temp=x-y; return y+(temp&(~(temp>>31)));}
inline void read(int &x)
{
    char ch=getchar();
    int f=1;
    for (;ch<'0' || ch>'9';ch=getchar()) if (ch=='-') f=-1;
    for (x=0;ch>='0' && ch<='9';ch=getchar()) x=x*10+ch-'0';
    if (f==-1) x=-x;
}
 
bool updmax(int t,int l,int r,int x)
{
    if (x<=T[t].min) return 1;
    if (T[t].se2>x)
    {
        T[t].sum+=(ll)T[t].num2*(x-T[t].min);
        T[t].max=max(T[t].max,T[t].min=x);
        T[t].se1=max(T[t].se1,x);
        if (T[t].max==x) T[t].se1=-inf,T[t].se2=inf,T[t].num1=T[t].num2=r-l+1;
        T[t].tmn=max(T[t].tmx=max(T[t].tmx,x),T[t].tmn);
        return 1;
    }
    return 0;
}
 
bool updmin(int t,int l,int r,int x)
{
    if (x>=T[t].max) return 1;
    if (T[t].se1<x)
    {
        T[t].sum-=(ll)T[t].num1*(T[t].max-x);
        T[t].min=min(T[t].min,T[t].max=x);
        T[t].se2=min(T[t].se2,x);
        if (T[t].min==x) T[t].se1=-inf,T[t].se2=inf,T[t].num1=T[t].num2=r-l+1;
        T[t].tmx=min(T[t].tmn=min(T[t].tmn,x),T[t].tmx);
        return 1;
    }
    return 0;
}
 
void pushup(int t)
{
    int ls=t<<1,rs=ls|1;
    T[t].num1=T[t].num2=0;
    if (T[ls].max>=T[rs].max) T[t].max=T[ls].max,T[t].num1=T[ls].num1;
    if (T[rs].max>=T[ls].max) T[t].max=T[rs].max,T[t].num1+=T[rs].num1;
    if (T[ls].min<=T[rs].min) T[t].min=T[ls].min,T[t].num2=T[ls].num2;
    if (T[rs].min<=T[ls].min) T[t].min=T[rs].min,T[t].num2+=T[rs].num2;
    int se=-inf;
    if (T[ls].max<T[rs].max) se=max(se,T[ls].max);
    else se=max(se,T[ls].se1);
    if (T[rs].max<T[ls].max) se=max(se,T[rs].max);
    else se=max(se,T[rs].se1);
    T[t].se1=se;
    se=inf;
    if (T[ls].min>T[rs].min) se=min(se,T[ls].min);
    else se=min(se,T[ls].se2);
    if (T[rs].min>T[ls].min) se=min(se,T[rs].min);
    else se=min(se,T[rs].se2);
    T[t].se2=se;
    T[t].sum=T[ls].sum+T[rs].sum;
}
 
void pushdown(int t,int l,int r)
{
    if (l==r) return;
    int mid=(l+r)>>1,ls=t<<1,rs=ls|1;
    int x=T[t].tag;
    if (x)
    {
        T[ls].max+=x;
        T[ls].min+=x;
        go(T[ls].se1,x);
        go(T[ls].se2,x);
        T[ls].tmx+=x;
        T[ls].tmn+=x;
        T[rs].tmx+=x;
        T[rs].tmn+=x;
        T[ls].sum+=(mid-l+1)*x;
        T[rs].max+=x;
        T[rs].min+=x;
        go(T[rs].se1,x);
        go(T[rs].se2,x);
        T[rs].sum+=(r-mid)*x;
        T[ls].tag+=x;
        T[rs].tag+=x;
        T[t].tag=0;
    }
    if (T[t].tmx!=-inf)
    {
        updmax(ls,l,mid,T[t].tmx);
        updmax(rs,mid+1,r,T[t].tmx);
        T[t].tmx=-inf;
    }
    if (T[t].tmn!=inf)
    {
        updmin(ls,l,mid,T[t].tmn);
        updmin(rs,mid+1,r,T[t].tmn);
        T[t].tmn=inf;
    }
}
 
void build(int t,int l,int r)
{
    T[t].tmx=-inf,T[t].tmn=inf;
    if (l==r)
    {
        T[t].max=T[t].min=T[t].sum=a[l];
        T[t].num1=T[t].num2=1;
        T[t].se1=-inf,T[t].se2=inf;
        return;
    }
    int mid=(l+r)>>1,ls=t<<1,rs=ls|1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    pushup(t);
}
 
void modify(int t,int l,int r)
{
    if (x<=l && r<=y)
    {
        if (tp==1)
        {
            T[t].max+=z,T[t].min+=z,T[t].tag+=z,T[t].tmx+=z,T[t].tmn+=z;
            T[t].sum+=(r-l+1)*z;
            go(T[t].se1,z),go(T[t].se2,z);
            pushdown(t,l,r);
            return;
        }
        else if (tp==2 && updmax(t,l,r,z)) return;
        else if (tp==3 && updmin(t,l,r,z)) return;
    }
    pushdown(t,l,r);
    int mid=(l+r)>>1,ls=t<<1,rs=ls|1;
    if (x<=mid) modify(ls,l,mid);
    if (y>mid) modify(rs,mid+1,r);
    pushup(t);
}
 
void query(int t,int l,int r)
{
    pushdown(t,l,r);
    if (x<=l && r<=y)
    {
        if (tp==4) ans+=T[t].sum;
        if (tp==5){if ((ll)T[t].max>ans) ans=T[t].max;}
        if (tp==6){if ((ll)T[t].min<ans) ans=T[t].min;}
        return;
    }
    int mid=(l+r)>>1,ls=t<<1,rs=ls|1;
    if (x<=mid) query(ls,l,mid);
    if (y>mid) query(rs,mid+1,r);
}
 
int main()
{
        read(n);
        for (i=1;i<=n;i++) scanf("%d",&a[i]);
        build(1,1,n);
        for (read(m);m;m--)
        {
            read(tp);
            if (tp<=3)
            {
                read(x),read(y),read(z);
                modify(1,1,n);
            }
            else
            {
                if (tp==4) ans=0;
                if (tp==5) ans=-inf;
                if (tp==6) ans=inf;
                read(x),read(y);
                query(1,1,n);
                printf("%lld\n",ans);
            }
        }
}