题意:给出n个数字,要你把这n个数字分成m堆,每一堆的价值是(max(val) - min(val))^2,要你求出分成m堆之后所有堆的价值之和的最小值。
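思路:先排序。排序后,最优方案里每一堆一定是连续的一段,所以令dp[i][j]表示把前j个数字分成i堆的最小价值,dp[i][j] = min(dp[i - 1][k] + (val[j] - val[k + 1])^2),其中第i堆为val[k + 1..j]。直接转移太慢,下面的代码用四边形不等式优化:记s[i][j]为dp[i][j]取到最小值的分割点k,则s[i - 1][j] <= s[i][j] <= s[i][j + 1],只需在这个范围内枚举k。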
代码:
// Solution 1: sort, then dp[i][j] = min cost of splitting the first j sorted values
// into i groups, with the split point searched only in [s[i-1][j], s[i][j+1]]
// (Knuth / quadrangle-inequality optimisation).
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int MAX_N = 10007;
// NOTE: despite the name, `ll` is a 32-bit int here (keeps each table ~200 MB).
// Intermediate sums below are formed in long long and stored back only when they fit.
typedef int ll;

/// Reads one (optionally negative) decimal integer from stdin into `ret`.
/// Returns false if EOF is reached before a digit or '-' is seen.
template <class T>
inline bool rd(T &ret) {
    int c = getchar();  // int, not char: keeps EOF distinct from byte 0xFF
    // Skip separators, but stop at EOF instead of looping forever.
    while (c != '-' && (c < '0' || c > '9')) {
        if (c == EOF) return false;
        c = getchar();
    }
    const int sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while ((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    ret *= sgn;
    return true;
}

int n, m;
int s[MAX_N >> 1][MAX_N], a[MAX_N];  // s[i][j]: split point k chosen for dp[i][j]
ll dp[MAX_N >> 1][MAX_N];            // dp[i][j]: i groups over sorted a[1..j]

int main() {
    int T = 0;
    if (scanf("%d", &T) != 1) return 0;
    int cas = 0;
    while (T-- > 0) {
        if (scanf("%d%d", &n, &m) != 2) break;
        for (int i = 1; i <= n; ++i) rd(a[i]);
        sort(a + 1, a + 1 + n);

        // Only rows 1..m and columns 0..n+1 of s are read below, so clearing that
        // region is equivalent to clearing the whole table, at a fraction of the cost.
        for (int i = 1; i <= m; ++i) memset(s[i], 0, sizeof(s[i][0]) * (n + 2));

        // dp[i][i] = 0 (every value alone); other cells start at "infinity".
        for (int i = 1; i <= m; ++i)
            for (int j = i; j <= n; ++j) dp[i][j] = (j <= i) ? 0 : 0x7f7f7f7f;

        // One group: cost is the spread of the whole prefix.
        for (int i = 1; i <= n; ++i) {
            const long long d = static_cast<long long>(a[i]) - a[1];
            s[1][i] = 1;
            dp[1][i] = static_cast<ll>(d * d);
        }
        // Sentinel upper bound for the rightmost column of each row.
        for (int i = 2; i <= m; ++i) s[i][n + 1] = n;

        // j runs downward so that s[i][j + 1] is already known when dp[i][j] is computed.
        for (int i = 2; i <= m; ++i)
            for (int j = n; j >= i; --j)
                for (int k = s[i - 1][j]; k <= s[i][j + 1]; ++k) {
                    // Last group is a[k+1..j]; compare in 64-bit so "infinity" + w
                    // or a large squared spread cannot overflow.
                    const long long d = static_cast<long long>(a[j]) - a[k + 1];
                    const long long cand = static_cast<long long>(dp[i - 1][k]) + d * d;
                    if (dp[i][j] > cand) {
                        dp[i][j] = static_cast<ll>(cand);  // cand < dp[i][j], so it fits
                        s[i][j] = k;
                    }
                }
        printf("Case %d: %d\n", ++cas, dp[m][n]);
    }
    return 0;
}
斜率优化:
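对于第i个数,它要作为第j个集合的最后一个数时(即计算dp[j][i]),设k1 < k2是它之前的两个候选分割点(第j个集合为a[k + 1..i]),那么取k2为分割点比取k1为分割点优的条件就是:
dp[j - 1][k2] + (a[i] - a[k2 + 1])^2 < dp[j - 1][k1] + (a[i] - a[k1 + 1])^2
=> dp[j - 1][k2] + a[k2 + 1]^2 + a[i]^2 - 2 * a[i] * a[k2 + 1] < dp[j - 1][k1] + a[k1 + 1]^2 + a[i]^2 - 2 * a[i] * a[k1 + 1]
=> dp[j - 1][k2] - dp[j - 1][k1] + a[k2 + 1]^2 - a[k1 + 1]^2 < 2 * a[i] * (a[k2 + 1] - a[k1 + 1])
记X(k) = a[k + 1],Y(k) = dp[j - 1][k] + a[k + 1]^2,上式即Y(k2) - Y(k1) < 2 * a[i] * (X(k2) - X(k1))。因为a已排序,X(k)随k单调不减,2 * a[i]也随i单调不减,所以可以用单调队列维护点(X(k), Y(k))的下凸壳。队首已不再更优的点直接出队;新点入队前,先弹掉队尾不在下凸壳上的点。这样每个点只进出队列一次,每一层转移均摊O(n)。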
代码:
// Solution 2: same DP as above, optimised with a monotone queue over the lower
// convex hull of the points (X(k), Y(k)) = (a[k+1], dp[i-1][k] + a[k+1]^2).
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

int Case, n, m, head, tail;
int a[10050], dp[5050][10050], q[10010];

// X coordinate of candidate split point k (the last group is a[k+1..j]).
static long long hullX(int k) { return a[k + 1]; }

// Y coordinate of candidate split point k on DP layer `row`: dp[row][k] + a[k+1]^2,
// computed in 64-bit so the square cannot overflow int.
static long long hullY(int row, int k) {
    const long long x = a[k + 1];
    return dp[row][k] + x * x;
}

// Returns dp[m][n]: minimal total cost of splitting sorted a[1..n] into m groups.
int work() {
    // One group: cost is the spread of the whole prefix.
    for (int i = 1; i <= n; i++) {
        const long long d = static_cast<long long>(a[i]) - a[1];
        dp[1][i] = static_cast<int>(d * d);
    }
    for (int i = 2; i <= m; i++) {
        head = tail = 0;
        q[tail++] = i - 1;  // smallest feasible split: first i-1 values form i-1 groups
        for (int j = i; j <= n; j++) {
            // Pop front candidates that are no better than their successor for a[j];
            // a[j] never decreases, so they can never become optimal again.
            while (head + 1 < tail) {
                const int p1 = q[head], p2 = q[head + 1];
                const long long dx = hullX(p2) - hullX(p1);
                const long long dy = hullY(i - 1, p2) - hullY(i - 1, p1);
                if (dy <= 2LL * a[j] * dx) head++;
                else break;
            }
            const int k = q[head];
            const long long d = static_cast<long long>(a[j]) - a[k + 1];
            dp[i][j] = static_cast<int>(dp[i - 1][k] + d * d);

            // Before appending j, pop back points that would no longer lie on the
            // lower hull (cross product done in 64-bit to avoid int overflow).
            while (head + 1 < tail) {
                const int p1 = q[tail - 2], p2 = q[tail - 1];
                const long long xa = hullX(p1), xb = hullX(p2), xc = hullX(j);
                const long long ya = hullY(i - 1, p1);
                const long long yb = hullY(i - 1, p2);
                const long long yc = hullY(i - 1, j);
                if ((yc - ya) * (xb - xa) <= (yb - ya) * (xc - xa)) tail--;
                else break;
            }
            q[tail++] = j;
        }
    }
    return dp[m][n];
}

int main() {
    int T = 0;
    if (scanf("%d", &T) != 1) return 0;
    while (T--) {
        if (scanf("%d%d", &n, &m) != 2) break;
        for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
        sort(a + 1, a + 1 + n);
        printf("Case %d: %d\n", ++Case, work());
    }
    return 0;
}