[BZOJ2850]巧克力王国(kd-tree)

时间:2022-12-17 16:58:36

题目描述

传送门

题解

维护一棵2dkd-tree,两维分别是x和y,维护这两维的最大最小值,再维护h的子树和sum
对于每一次询问暴力寻找合法的子树,如果子树的最大值<=c的话直接加,否则如果存在<=c的话,就再到子树里暴力
需要注意的一点是因为有负数,所以不一定是最大的乘起来最大,所以四个都要判断一下

代码

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;
#define LL long long
#define N 50005

int n,m,root,cmpd;
LL x,y,h,a,b,c,ans;
struct data
{
    int l,r;
    LL val,sum;
    LL d[2],mn[2],mx[2];
}tr[N];

void update(int x)
{
    int l=tr[x].l,r=tr[x].r;
    tr[x].sum=tr[x].val;
    if (l)
    {
        tr[x].mx[0]=max(tr[x].mx[0],tr[l].mx[0]);
        tr[x].mn[0]=min(tr[x].mn[0],tr[l].mn[0]);
        tr[x].mx[1]=max(tr[x].mx[1],tr[l].mx[1]);
        tr[x].mn[1]=min(tr[x].mn[1],tr[l].mn[1]);
        tr[x].sum+=tr[l].sum;
    }
    if (r)
    {
        tr[x].mx[0]=max(tr[x].mx[0],tr[r].mx[0]);
        tr[x].mn[0]=min(tr[x].mn[0],tr[r].mn[0]);
        tr[x].mx[1]=max(tr[x].mx[1],tr[r].mx[1]);
        tr[x].mn[1]=min(tr[x].mn[1],tr[r].mn[1]);
        tr[x].sum+=tr[r].sum;
    }
}
int cmp(data a,data b)
{
    return a.d[cmpd]<b.d[cmpd]||a.d[cmpd]==b.d[cmpd]&&a.d[cmpd^1]<b.d[cmpd^1];
}
int build(int l,int r,int d)
{
    int mid=(l+r)>>1;
    cmpd=d;
    nth_element(tr+l,tr+mid,tr+r+1,cmp);
    if (l<mid) tr[mid].l=build(l,mid-1,d^1);
    if (mid<r) tr[mid].r=build(mid+1,r,d^1);
    update(mid);
    return mid;
}
bool ok(int id)
{
    return (tr[id].d[0]*a+tr[id].d[1]*b<c);
}
int check(int id)
{
    int ans=0;
    if (tr[id].mx[0]*a+tr[id].mx[1]*b<c) ++ans;
    if (tr[id].mx[0]*a+tr[id].mn[1]*b<c) ++ans;
    if (tr[id].mn[0]*a+tr[id].mx[1]*b<c) ++ans;
    if (tr[id].mn[0]*a+tr[id].mn[1]*b<c) ++ans;
    return ans;
}
void query(int now)
{
    if (ok(now)) ans+=tr[now].val;
    if (tr[now].l)
    {
        int t=check(tr[now].l);
        if (t==4) ans+=tr[tr[now].l].sum;
        else if (t) query(tr[now].l);
    }
    if (tr[now].r)
    {
        int t=check(tr[now].r);
        if (t==4) ans+=tr[tr[now].r].sum;
        else if (t) query(tr[now].r);
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;++i)
    {
        scanf("%lld%lld%lld",&x,&y,&h);
        tr[i].mx[0]=tr[i].mn[0]=tr[i].d[0]=x;
        tr[i].mx[1]=tr[i].mn[1]=tr[i].d[1]=y;
        tr[i].val=tr[i].sum=h;
    }
    root=build(1,n,0);
    for (int i=1;i<=m;++i)
    {
        scanf("%lld%lld%lld",&a,&b,&c);
        ans=0LL;
        query(root);
        printf("%lld\n",ans);
    }
}