JZOJ 6664. 【2020.05.28省选模拟】最优化

时间:2021-11-09 00:39:46

\(\text{Solution}\)

原题:\(\text{Honorable Mention}\)

一个费用流做法,\(S\)\(2i-1\) 连流量为 \(1\),费用为 \(0\) 的边,\(2i\)\(T\) 连流量为 \(1\),费用为 \(0\) 的边
\(2i-1\)\(2i\) 连流量为 \(1\),费用为 \(a_i\) 的边。然后增广 \(k\) 次即为答案

既然用了费用流模型那么这个关于 \(k\) 的函数自然是凸函数
于是可以考虑一些优化
比如,多组询问想到将区间拆成 \(O(\log n)\) 段线段树上的区间,处理每个区间上的函数值,合并可以做到 \(O(n)\)
也就是要维护凸包,闵科夫斯基和,就可以 \(O(n\log n)\) 预处理凸包了
但这样还是 \(O(nQ)\) 的,仍然是暴力
每个询问有 \(O(\log n)\) 个凸包,合并代价很高
考虑 \(\text{WQS}\) 二分的威力,想想凸包合并时 \(f_{i+j}=f_i+f_j\),又点 \((i,f_i)\) 考虑成 \(f_i=ki+b_i,f_j=kj+b_j\)
那么合并后的凸包 \((i+j,f_{i+j})\)\(f_{i+j}=k(i+j)+b_{i+j}\),也就是启示我们 \(\text{WQS}\) 二分斜率后在每个凸包上找到对应斜率的值直接合并值,然后用 \(WQS\) 二分的方式算出答案
于是就做到 \(O(n\log n+Q\log V \log ^2n)\)

注意事项:

  1. \(\text{WQS}\) 一定要注意斜率变大或者变小会导致分的段数变多还是变少,同时二分写法要和求最值的写法相统一
    如本题写了分的段数 \(\ge k\) 时更新答案,那么求最值,值相等时优先取段数多的
  2. 传参事项,传 vector 时加个 & 就不会发生复制导致超时的问题了(因为这题某函数本身并不需要遍历整个 vector,只要也只能 \(O(\log n)\) 做某些特定事情复杂度才对)

\(\text{Code}\)

#include <bits/stdc++.h> 
#define IN inline
#define eb emplace_back
#define LL long long
#define Vec vector<LL>
using namespace std;

template<typename Tp>
IN void read(Tp &x) {
    x = 0; char ch = getchar(); int f = 0;
    for(; !isdigit(ch); f |= (ch == '-'), ch = getchar());
    for(; isdigit(ch); x = (x<<3)+(x<<1)+(ch^48), ch = getchar());
    if (f) x = ~x + 1;
}

const int N = 35005;
const LL INF = 2e9;
int n, a[N];
LL pans[2], pcnt[2], tans[2], tcnt[2];

struct SegmentTree {
    #define ls (p << 1)
    #define rs (ls | 1)
    Vec tr[N << 2][2][2];
    
    IN Vec merge(Vec &a, Vec &b) {
    	if (a.empty() || b.empty()) return{};
    	Vec ret(a.size() + b.size() - 1, -INF);
    	int l = 0, r = 0; if (a[0] != -INF && b[0] != -INF) ret[0] = a[0] + b[0];
    	while (l < a.size() || r < b.size()) {
    		if (l >= a.size() - 1 && r >= b.size() - 1) break;
    		if (l == a.size() - 1) ++r; else if (r == b.size() - 1) ++l;
    		else if (a[l + 1] - a[l] > b[r + 1] - b[r]) ++l; else ++r;
    		if (a[l] != -INF && b[r] != -INF) ret[l + r] = a[l] + b[r];
        }
    	return ret;
    }
    IN void shift(Vec tmp, Vec &res) {
        for(int i = 1; i < tmp.size(); i++) res[i - 1] = max(res[i - 1], tmp[i]);
    }
    IN void pushup(int p) {
        for(int i = 0; i < 2; i++)
            for(int j = 0; j < 2; j++) {
                tr[p][i][j] = merge(tr[ls][i][0], tr[rs][0][j]);
                shift(merge(tr[ls][i][1], tr[rs][1][j]), tr[p][i][j]);
            }
    }
    void build(int p, int l, int r) {
        if (l == r) {
            tr[p][0][0] = {0, a[l]}, tr[p][0][1] = tr[p][1][0] = tr[p][1][1] = {-INF, a[l]};
            return;
        }
        int mid = l + r >> 1; build(ls, l, mid), build(rs, mid + 1, r), pushup(p);
    }
    
    IN void update(Vec &a, LL k, int x, int y) {
        int l = 1, r = a.size() - 1, mid = l + r >> 1, ret = 0;
        for(; l <= r; mid = l + r >> 1)
            if (a[mid] - a[mid - 1] >= k) ret = mid, l = mid + 1; else r = mid - 1;
        if (a[ret] == -INF) return;
        for(int i = 0; i < 2; i++) {
            LL w = tans[i] + a[ret] - k * ret;
            if (!x && pans[y] <= w) pans[y] = w, pcnt[y] = tcnt[i] + ret;
            w = tans[i] + a[ret] - k * (ret - i);
            if (x && (pans[y] < w || (pans[y] == w && tcnt[i] + ret - i > pcnt[y])))
                pans[y] = w, pcnt[y] = tcnt[i] + ret - i;
        }
    }
    IN void Merge(int p, LL k) {
        tcnt[0] = pcnt[0], tcnt[1] = pcnt[1], tans[0] = pans[0], tans[1] = pans[1];
        pans[1] = -INF, pans[0] = pcnt[0] = pcnt[1] = 0;
        for(int i = 0; i < 2; i++)
            for(int j = 0; j < 2; j++) update(tr[p][i][j], k, i, j);
    }
    
    void query(int p, int l, int r, int x, int y, LL k) {
        if (x <= l && r <= y) return Merge(p, k), void();
        int mid = l + r >> 1;
        if (x <= mid) query(ls, l, mid, x, y, k);
        if (y > mid) query(rs, mid + 1, r, x, y, k);
    }
}seg;

void Query(int L, int R, int k) {
    LL res = 0, l = -1e10, r = 1e10, mid = l + r >> 1;
    for(; l <= r; mid = l + r >> 1) {
        pans[1] = -INF, pans[0] = pcnt[0] = pcnt[1] = 0, seg.query(1, 1, n, L, R, mid);
        int z = (pans[0] > pans[1] ? 0 : 1);
        if (pcnt[z] >= k) res = pans[z] + mid * k, l = mid + 1; else r = mid - 1;
    }
    printf("%lld\n", res);
}

int main() {
    freopen("maximize.in", "r", stdin);
    freopen("maximize.out", "w", stdout);
    int q; read(n), read(q);
    for(int i = 1; i <= n; i++) read(a[i]);
    seg.build(1, 1, n);
    for(int i = 1, l, r, k; i <= q; i++) read(l), read(r), read(k), Query(l, r, k);
}