[poj3017] Cut the Sequence (DP + 单调队列优化 + 平衡树优化)

时间:2021-08-04 21:15:03

DP + 单调队列优化 + 平衡树 好题


Description

Given an integer sequence { an } of length N, you are to cut the sequence into several parts every one of which is a consecutive subsequence of the original sequence. Every part must satisfy that the sum of the integers in the part is not greater than a given integer M. You are to find a cutting that minimizes the sum of the maximum integer of each part.

Input

The first line of input contains two integer N (0 < N ≤ 100 000), M. The following line contains N integers describes the integer sequence. Every integer in the sequence is between 0 and 1 000 000 inclusively.

Output

Output one integer which is the minimum sum of the maximum integer of each part. If no such cuttings exist, output −1.

Sample Input

8 17
2 2 2 8 1 8 2 1

Sample Output

12

Hint

Use 64-bit integer type to hold M.


题目大意

给你一个长度为 n 序列,要你把它分成不同块,使得每块最大值的和最小,每块内数字的和不超过 m。题意很简单。

题解

这道题是一道非常好的单调队列优化 DP 的题。首先,我们忽略掉 m 这个条件,我们可以很直接地看出,可以设 dp[i] 为前 i 个数可以得到的最优答案。转移方程为:

dp[i] = min{ dp[j] + max{a[j+1], a[j+2] ... a[i]} };

朴素的转移时间复杂度为\(O(n^2)\), 所以我们考虑用某种方法来优化掉多出来的这个\(n\)
首先我们来看怎样优化求区间内最大值,我们可以想到用rmq,但很快就被否定,再一想,单调队列!!
很明显,dp[j] 是单调不下降的,所以,当当前块内最大值确定时,j 越小越好,我们可以模拟边界 j 由后向前推进时的情况,如果 a[j] 大于了当前的最大值,那么,当前块内的max值会增大,否则不变。
这个不变就为我们提供了优化的前提,我们可以根据把一段不变的区间压缩成只更新一次,这样,我们就去掉了多余的决策。
于是,我们维护一个单调队列 q[], q[] 中维护的是单调下降的下标,
如样例 对a[4~8] {8,1,8,2,1},i = 8, 单调队列中存的值是{6,7,8},返回a[]中的值为{8,2,1}。
当 4 <= j <= 6 时,max{a[j] ... a[i]} 都等于 8;
当7 <= j <= 7 时,max{...}等于 2;
当8 <= j <= 8 时,max{...}等于 1;
然后我们加上限制条件 m,我们只用从单调队列头部删除不满足要求的值即可。

然后我们只用从这几个点来更新,是不是优化了很多?
等一等,如果数组a[] 本身就是单调下降的怎么办,在这种情况下,时间复杂度依然可以高达\(O(n^2)\)。不用急,我们再继续往下探究,看是否可以在\(O(1)\)\(O(log_n)\)的复杂度内快速找到使答案最小的那个 j 。

对于一个 j,如果其在更新 dp[i - 1] 时就已经出现,并且在更新 dp[i] 时没有被删除,那么可以得到:

dp[j-1] + max{a[j] ... a[i-1]} == dp[j-1] + max{a[j] ... a[i]}

即,我们可以用本来是更新 dp[i - 1] 的 j 来继续更新 dp[i]。
所以我们可以将 dp[j-1] + max{a[j] ... a[i]} 用一颗平衡树来维护,每次直接取出最小值来更新 dp[i] 即可。
于是这道题就完美解决了,时间复杂度为\(O(nlog_n)\)

但是因为数据太弱,我们用 STL 中的 set 就可以通过,set 不仅常数太大,并且erase是直接在树中删掉某个数,而不是使值减一,这样,如果出现两个相同的数同时出现在 set 中,在删除时就会导致答案错误。所以,正解应该手写一颗平衡树。

更有甚者,理论复杂度为\(O(n^2)\)的不用平衡树的算法居然比手写平衡树还要快。。。poj你数据是有多水 2333


UPD: 之前有一点错了, set确实是不行, 但是STL中还提供了multiset多元集合,支持多个相同元素。用法和set几乎一样。

下面是\(O(n^2)\)\(O(nlog_n)\)两份代码。

\(O(n^2)\)代码

#include <iostream>
#include <set>
#include <cstdio>
using namespace std;
typedef long long LL;

const int maxn = 1e5 + 5;
int n;
LL m;
LL a[maxn],sum[maxn];
int q[maxn];
LL dp[maxn];

int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    while(scanf("%d%lld",&n,&m) != EOF) {
        int ok = 0;
        for(int i = 1;i <= n;i++) {
            scanf("%lld",a+i);if(a[i] > m) ok = 1;
            sum[i] = a[i] + sum[i-1]; 
        }
        int p = 1,front = 0,tail = -1;
        for(int i = 1;i <= n;i++) {
            if(ok)break;
            while(sum[i] - sum[p-1] > m)p++;
            
            while(front <= tail && a[q[tail]] <= a[i]) tail--; 
            q[++tail] = i;
            
            while(q[front] < p) front++;
            
            dp[i] = dp[p-1] + a[q[front]];
            for(int j = front;j < tail;j++) dp[i] = min(dp[i],dp[q[j]] + a[q[j+1]]);
        }
        if(ok) dp[n] = -1;
        printf("%lld\n",dp[n]);
    }
    
    
    return 0;
}

\(O(nlog_n)\)代码

#include <iostream>
#include <set>
#include <cstdio>
using namespace std;
typedef long long LL;

const int maxn = 1e5 + 5;
int n;
LL m;
LL a[maxn],sum[maxn];
int q[maxn];
LL dp[maxn];

multiset <LL> s;

int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    while(scanf("%d%lld",&n,&m) != EOF) {
        int ok = 0;
        for(int i = 1;i <= n;i++) {
            scanf("%lld",a+i);if(a[i] > m) ok = 1;
            sum[i] = a[i] + sum[i-1]; 
        }
        int p = 1,front = 0,tail = -1;
        s.clear();
        for(int i = 1;i <= n;i++) {
            if(ok)break;
            while(sum[i] - sum[p-1] > m)p++;
            
            while(front <= tail && a[q[tail]] <= a[i]) {
                if(front < tail) s.erase(dp[q[tail-1]] + a[q[tail]]);
                tail--;
            } 
            q[++tail] = i;
            
            if(front < tail) s.insert(dp[q[tail-1]] + a[i]);
            
            while(q[front] < p) {
                if(front < tail) s.erase(dp[q[front]] + a[q[front+1]]);
                front++;
            }
            
            dp[i] = dp[p-1] + a[q[front]];
            if(front < tail) dp[i] = min(dp[i], *s.begin());
        }
        if(ok) dp[n] = -1;
        printf("%lld\n",dp[n]);
    }
    
    return 0;
}