[SDOI 2016]征途

时间:2022-03-02 04:56:06

Description

题库链接

将一个长度为 \(n\) 的正整数序列划分为 \(m\) 个连续的非空段,记各段的元素和为 \(a_1,\dots,a_m\) ,求所有划分方案中 \(a_i\) 的方差 \(v\) 的最小值。输出 \(v\times m^2\) 。

\(1\leq n\leq 3000\)

Solution

记 \(a_i\) 为第 \(i\) 段的元素和, \(\overline{a}\) 为它们的平均数,容易发现答案就是:

\[m^2\frac{\sum\limits_{i=1}^m(a_i-\overline{a})^2}{m}=m\sum\limits_{i=1}^m\left(a_i-\frac{\sum\limits_{j=1}^m a_j}{m}\right)^2\]

记 \(s=\sum\limits_{i=1}^m a_i\) ,它就是整个序列的总和,与划分方式无关。把平方展开得:

\[m\sum_{i=1}^m a_i^2-2s\sum_{i=1}^m a_i+s^2\]

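设 \(S_j\) 为序列的前缀和, \(f_{i,j}\) 为把前 \(j\) 个数分成 \(i\) 段时 \(m\sum a^2-2s\sum a\) 的最小值,则
\[f_{i,j}=\min_{k<j}\left\{f_{i-1,k}+m(S_j-S_k)^2-2s(S_j-S_k)\right\}\]
答案为 \(f_{m,n}+s^2\) 。这个转移就可以斜率优化了,总复杂度 \(O(nm)\) 。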

upd

当时做的时候没注意到……其实由于 \(\sum\limits_{i=1}^m a_i=s\) ,最后的式子还可以进一步化成 \[m\sum_{i=1}^m a_i^2-s^2\]

Code

//It is made by Awson on 2018.3.19
#include <bits/stdc++.h>
// Shorthand type name and helper macros used throughout the file.
#define LL long long
// NOTE(review): dob is not referenced anywhere in the visible code.
#define dob complex<double>
// Abs/Max/Min are function-like macros: an argument may be evaluated more
// than once, so do not pass expressions with side effects.
// Max and Min are not referenced in the visible code; Abs is used by write().
#define Abs(a) ((a) < 0 ? (-(a)) : (a))
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
// XOR swap for integer lvalues. If a and b name the same object the value is
// zeroed; std::swap has no such pitfall. Not referenced in the visible code.
#define Swap(a, b) ((a) ^= (b), (b) ^= (a), (a) ^= (b))
// Print x via write() (defined further down) followed by a newline.
#define writeln(x) (write(x), putchar('\n'))
// Lowest set bit of x (x & -x). Not referenced in the visible code.
#define lowbit(x) ((x)&(-(x)))
using namespace std;
// Maximum n from the statement (n <= 3000); sizes the global arrays below.
const int N = 3000;
// Read one integer from stdin into x.
// Characters before the first digit are skipped; the value is negated only
// when the character immediately preceding the digits is '-'.
// If EOF is reached before any digit, x is set to 0 (instead of looping forever).
// The character following the last digit is consumed.
void read(LL &x) {
    int ch = getchar();  // int, not char: EOF must stay distinguishable from data
    bool negative = false;
    while (ch != EOF && !isdigit(ch)) {
        negative = (ch == '-');
        ch = getchar();
    }
    for (x = 0; ch != EOF && isdigit(ch); ch = getchar()) x = x*10+(ch-'0');
    if (negative) x = -x;
}
// Emit the decimal digits of x (expected to be non-negative) to stdout,
// most significant digit first, with no sign and no trailing newline.
void print(LL x) {
    char digits[24];
    int cnt = 0;
    do {
        digits[cnt++] = static_cast<char>('0'+x%10);
        x /= 10;
    } while (x > 0);
    while (cnt > 0) putchar(digits[--cnt]);
}
// Emit x in signed decimal form to stdout (no trailing newline).
void write(LL x) {
    if (x >= 0) {
        print(x);
        return;
    }
    putchar('-');
    print(-x);
}
// n, m: input sizes. sum[i]: prefix sums filled by work(). s: total, sum[n].
LL n, m, sum[N+5], s;
// DP table filled by work(): f[i][j] describes splitting the first j numbers
// into i segments.
LL f[N+5][N+5];
int q[N+5], head, tail; LL deltax(int p, int q) {return m*2*(sum[q]-sum[p]); }
LL deltay(int p, int q) {return m*sum[q]*sum[q]-m*sum[p]*sum[p]+s*2*sum[q]-s*2*sum[p]; }
void work() {
read(n), read(m); for (int i = 1; i <= n; i++) read(sum[i]), sum[i] += sum[i-1];
s = sum[n]; memset(f, 127/3, sizeof(f));
f[0][0] = 0;
for (int i = 1; i <= m; i++) {
head = tail = 0; q[tail++] = 0;
for (int j = 1; j <= n; j++) {
while (tail-head > 1 &&
f[i-1][q[head+1]]-f[i-1][q[head]]+deltay(q[head], q[head+1]) <=
sum[j]*deltax(q[head], q[head+1])) ++head;
f[i][j] = f[i-1][q[head]]+m*(sum[j]-sum[q[head]])*(sum[j]-sum[q[head]])-(sum[j]-sum[q[head]])*2*s;
while (tail-head > 1 &&
(f[i-1][q[tail-1]]-f[i-1][q[tail-2]]+deltay(q[tail-2], q[tail-1]))*deltax(q[tail-1], j) >=
(f[i-1][j]-f[i-1][q[tail-1]]+deltay(q[tail-1], j))*deltax(q[tail-2], q[tail-1])) --tail;
q[tail++] = j;
}
}
writeln(f[m][n]+s*s);
}
// Entry point: all work happens in work(); falling off the end of main
// returns 0.
int main() {
    work();
}