[luogu3648][bzoj3675][APIO2014]序列分割【动态规划+斜率优化】

时间:2022-06-25 21:34:03

题目大意

让你把一个数列分成k+1个部分,使分成乘积分成各个段乘积和最大。

分析

首先肯定是无法开下n \(\times\) n的数组,那么来一个小技巧:因为我们知道k的状态肯定是从k-1的状态转移过来的,而且只从k-1的状态转移过来,那么我们就记录一下k-1和k的状态。
然后我们再来动态规划:
状态肯定是:\(f[i]\)表示前i个数,分成j段(j枚举,滚动数组优化成n)。
转移方程就是:
\[f[i]=max(g[j]+sum[j]\times(sum[i]-sum[k]))\]
一开始看错题目了,以为是乘积的乘积最大,而且是各个段的。一直过不了样例。
那么以上的转移方程就可以拿到一小部分分数了。


以下是斜率优化的部分:
按照斜率优化的套路,首先假设j>k,且j的状态比k要优。
那么得到了式子:
\[g[j]+sum[j]\times(sum[i]-sum[j])>g[k]+sum[k]\times(sum[i]-sum[k])\]
化简就得到了以下的不等式:(初中知识就可以了)

\[\frac{(g[j]-sum[j]^2)-(g[k]-sum[k]^2)}{sum[k]-sum[j]}<=sum[i]\]

那么单调队列维护凸包就可以了。

开了o2卡了常才过掉的垃圾代码

// luogu-judger-enable-o2
#include <bits/stdc++.h>
#define N 100005
#define ll long long
#define db double
using namespace std;
ll g[N][255], f[N], dp[N], sum[N], a[N];
int n, m;
int q[N], ans[N];
template <typename T>
inline void read(T &x) {
    x = 0; T fl = 1;
    char ch = 0;
    while (ch < '0' || ch > '9') {
        if (ch == '-') fl = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 1) + (x << 3) + (ch ^ 48);
        ch = getchar();
    }
    x *= fl;
}
ll X(int i) {
    return sum[i];
}
ll Y(int i) {
    return f[i] - sum[i] * sum[i];
}
db slope(int i, int j) {
    if (sum[i] == sum[j]) return - 1e18;
    return ((1.0 * (Y(i) - Y(j))) / (1.0 * (X(j) - X(i))));
}
int main() {
    read(n); read(m);
    for (int i = 1; i <= n; i ++) {
        read(a[i]);
        sum[i] = sum[i - 1] + a[i];
    }
    for (int j = 1; j <= m; j ++) {
        int h = 0, t = 0;
        for (int i = 1; i <= n; i ++) {
            while (h < t && slope(q[h], q[h + 1]) <= sum[i]) ++ h;
            int k = q[h];
            dp[i] = f[k] + sum[k] * (sum[i] - sum[k]);
            g[i][j] = k;
            while (h < t && slope(q[t - 1], q[t]) >= slope(q[t], i)) -- t;
            q[++ t] = i;
        }
        for (int i = 1; i <= n; i ++) f[i] = dp[i];
    }
    printf("%lld\n", dp[n]);
    int tot = 0;
    for (int x = n, i = m; i >= 1; i --) {
        x = g[x][i];
        ans[++ tot] = x;
    }
    for (int i = tot; i >= 1; i --) printf("%d ", ans[i]);
    return 0;
}