bzoj3110 [Zjoi2013]K大数查询——线段树套线段树

时间:2024-12-24 22:05:44

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3110

外层权值线段树套内层区间线段树;

之所以外层权值内层区间,是因为区间线段树需要标记下传,所以写在内层比较方便;

然而空间太大了,所以动态开点,大约每个外层线段树的点上有 logn 个内层线段树点;

最开始写的不知为何很快就WA:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const maxn=5e4+,maxm=maxn*;
int n,m,rt[maxn<<],ls[maxm],rs[maxm],cnt;
int a[maxn],b[maxn],tp[maxn],c[maxn],tmp[maxn],tot;
ll sum[maxm],lzy[maxm];//ll lzy 防止计算时爆int
void pushdown(int x,int l,int r)
{
if(!lzy[x])return;
if(!ls[x])ls[x]=++cnt;
if(!rs[x])rs[x]=++cnt;
int mid=((l+r)>>);
sum[ls[x]]+=(mid-l+)*lzy[x]; sum[rs[x]]+=(r-mid)*lzy[x];
lzy[ls[x]]+=lzy[x]; lzy[rs[x]]+=lzy[x];
lzy[x]=;
}
void add(int &x,int l,int r,int L,int R)
{
if(!x)x=++cnt;
if(l>=L&&r<=R){sum[x]+=(r-l+); lzy[x]++; return;}
// pushdown(x,l,r);
int mid=((l+r)>>);
if(mid>=L)add(ls[x],l,mid,L,R);
if(mid<R)add(rs[x],mid+,r,L,R);
sum[x]=sum[ls[x]]+sum[rs[x]];
}
void insert(int x,int l,int r,int tl,int tr,int c)
{
add(rt[x],,n,tl,tr);
if(l==r)return;//
int mid=((l+r)>>);//权值区间
if(c<=mid)insert(x<<,l,mid,tl,tr,c);
else insert(x<<|,mid+,r,tl,tr,c);
}
ll ask(int x,int l,int r,int L,int R)
{
if(!x)return ;//
if(l>=L&&r<=R)return sum[x];
pushdown(x,l,r);
int mid=((l+r)>>); ll ret=;
// if(mid>=L)ret+=ask(x<<1,l,mid,L,R);
// if(mid<R)ret+=ask(x<<1|1,mid+1,r,L,R);
if(mid>=L)ret+=ask(ls[x],l,mid,L,R);
if(mid<R)ret+=ask(rs[x],mid+,r,L,R);//别写串了!
return ret;
}
int query(int x,int l,int r,int L,int R,int k)
{
if(l==r)return l;
ll tmp=ask(rt[x<<|],,n,L,R),mid=((l+r)>>);//查询第k大
if(tmp>=k)return query(x<<|,mid+,r,L,R,k);//
else return query(x<<,l,mid,L,R,k-tmp);//
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=;i<=m;i++)
{
scanf("%d%d%d%d",&tp[i],&a[i],&b[i],&c[i]);
if(tp[i]==)tmp[++tot]=c[i];
}
sort(tmp+,tmp+n+); tot=unique(tmp+,tmp+n+)-tmp-;
for(int i=;i<=m;i++)
{
if(tp[i]==)
{
int tt=lower_bound(tmp+,tmp+tot+,c[i])-tmp;
insert(,,tot,a[i],b[i],tt);
}
else printf("%d\n",tmp[query(,,tot,a[i],b[i],c[i])]);
}
return ;
}

不会改了,所以模仿别人写了个标记永久化,然后WA惨;

因为不太熟悉标记永久化,没注意到更新以及查询时要注意不能重复,改了半天,终于好了...

学到了点标记永久化的细节。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const maxn=5e4+,maxm=maxn*;
int n,m,rt[maxn<<],ls[maxm],rs[maxm],cnt;
int a[maxn],b[maxn],tp[maxn],c[maxn],tmp[maxn],tot,L,R;
ll sum[maxm],lzy[maxm];//ll lzy 防止计算时爆int
//void pushdown(int x,int l,int r)
//{
// if(!lzy[x])return;
// if(!ls[x])ls[x]=++cnt;
// if(!rs[x])rs[x]=++cnt;
// int mid=((l+r)>>1);
// sum[ls[x]]+=(mid-l+1)*lzy[x]; sum[rs[x]]+=(r-mid)*lzy[x];
// lzy[ls[x]]+=lzy[x]; lzy[rs[x]]+=lzy[x];
// lzy[x]=0;
//}
void add(int &x,int l,int r,int L,int R)
{
if(!x)x=++cnt;
if(l==L&&r==R){/*sum[x]+=(r-l+1);*/ lzy[x]++; return;}//==
sum[x]+=(R-L+);
// pushdown(x,l,r);
int mid=((l+r)>>);
// if(mid>=L)add(ls[x],l,mid,L,R);//会重复计算sum!(因为标记永久化)
// if(mid<R)add(rs[x],mid+1,r,L,R);
// sum[x]=sum[ls[x]]+sum[rs[x]];
if(mid<L)add(rs[x],mid+,r,L,R);
else if(mid>=R)add(ls[x],l,mid,L,R);
else add(ls[x],l,mid,L,mid),add(rs[x],mid+,r,mid+,R);
}
ll ask(int x,int l,int r,int L,int R)
{
if(!x)return ;//
ll ret=lzy[x]*(R-L+);
if(l==L&&r==R)return ret+sum[x];//
// pushdown(x,l,r);
int mid=((l+r)>>);
// if(mid>=L)ret+=ask(x<<1,l,mid,L,R);
// if(mid<R)ret+=ask(x<<1|1,mid+1,r,L,R);//别把ls,rs和x<<1,x<<1|1写串了!
// if(mid>=L)ret+=ask(ls[x],l,mid);
// if(mid<R)ret+=ask(rs[x],mid+1,r); //会重复计算!!
// printf("ret=%lld\n",ret);
if(mid<L)return ret+ask(rs[x],mid+,r,L,R);
else if(mid>=R)return ret+ask(ls[x],l,mid,L,R);
else return ret+ask(ls[x],l,mid,L,mid)+ask(rs[x],mid+,r,mid+,R);
}
int query(int x,int l,int r,int k)
{
if(l==r)return l;
ll tmp=ask(rt[x<<|],,n,L,R),mid=((l+r)>>);//查询第k大
if(tmp>=k)return query(x<<|,mid+,r,k);//
else return query(x<<,l,mid,k-tmp);//
}
void insert(int x,int l,int r,int c)
{
add(rt[x],,n,L,R);
if(l==r)return;//
int mid=((l+r)>>);//权值区间
if(c<=mid)insert(x<<,l,mid,c);
else insert(x<<|,mid+,r,c);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=;i<=m;i++)
{
scanf("%d%d%d%d",&tp[i],&a[i],&b[i],&c[i]);
if(tp[i]==)tmp[++tot]=c[i];
}
sort(tmp+,tmp+n+); tot=unique(tmp+,tmp+n+)-tmp-;
for(int i=;i<=m;i++)
{
L=a[i],R=b[i];
if(tp[i]==)
{
int tt=lower_bound(tmp+,tmp+tot+,c[i])-tmp;
insert(,,tot,tt);
}
else printf("%d\n",tmp[query(,,tot,c[i])]);
}
return ;
}