POJ3977 Subset

时间:2023-11-22 18:41:56

嘟嘟嘟



这个数据范围显然是折半搜索。

把序列分成两半,枚举前一半的子集,存下来。然后再枚举后一半的子集,二分查找。

细节:

1.最优解可能只在一半的子集里,所以枚举的时候也要更新答案。

2.对于当前结果\(tot\),二分查找\(-tot\)的时候要把\(-tot\)两边的元素都和\(tot\)加起来试一下,而不是只加当前二分查找到的值。

3.用二进制枚举比较快。

4.得去重,即和相同,保留元素个数最小的集合。(扫一遍即可)

5.更新的时候元素个数别忘了是两部分之和,刚开始我因为这个没写对拍了好一会儿(而且小数据还都过了……)。

6.\(INF\)要开到\(1e18\),因为数据范围是\(a _ i \leqslant 1e15\)。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define rg register
typedef long long ll;
typedef double db;
const ll INF = 1e18;
const db eps = 1e-8;
const int maxn = 36;
const int maxp = 3e5 + 5;
inline ll read()
{
ll ans = 0;
char ch = getchar(), last = ' ';
while(!isdigit(ch)) last = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(last == '-') ans = -ans;
return ans;
}
inline void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
} int n, m, q, cnt = 0;
ll a[maxn];
struct Node
{
ll sum; int num;
bool operator < (const Node& oth)const
{
return sum < oth.sum || (sum == oth.sum && num < oth.num);
}
}t[maxp], s[maxp];
ll Min = INF;
int ans; ll Abs(ll x) {return x < 0 ? -x : x;} int main()
{
while(scanf("%d", &n) && n)
{
m = n >> 1; q = n - m; cnt = 0;
Min = INF; ans = 1;
for(int i = 1; i <= n; ++i) a[i] = read();
for(int i = 1; i <= n; ++i) Min = min(Min, Abs(a[i]));
for(int i = 1; i < (1 << m); ++i)
{
ll tot = 0; int tcnt = 0;
for(int j = 0; j < m; ++j)
if((1 << j) & i) tot += a[j + 1], tcnt++;
if(Abs(tot) < Min) Min = Abs(tot), ans = tcnt;
else if(Abs(tot) == Min) ans = min(ans, tcnt);
t[++cnt] = (Node){tot, tcnt};
}
sort(t + 1, t + cnt + 1);
int scnt = 0, x = 1;
for(int i = 2; i <= cnt; ++i)
{
if(t[i].sum != t[x].sum) s[++scnt] = t[x], x = i;
else t[x].num = min(t[x].num, t[i].num);
}
if(t[x].sum != s[scnt].sum) s[++scnt] = t[x];
for(int i = 1; i < (1 << q); ++i)
{
ll tot = 0; int tcnt = 0;
for(int j = 0; j < q; ++j)
if((1 << j) & i) tot += a[j + m + 1], tcnt++;
if(Abs(tot) < Min) Min = Abs(tot), ans = tcnt;
else if(Abs(tot) == Min) ans = min(ans, tcnt);
int pos = lower_bound(s + 1, s + scnt + 1, (Node){-tot, 0}) - s;
if(pos && pos <= scnt)
{
ll tp = Abs(tot + s[pos].sum);
if(tp < Min) Min = tp, ans = s[pos].num + tcnt;
else if(tp == Min) ans = min(ans, s[pos].num + tcnt);
}
if(pos - 1 > 0 && pos - 1 <= scnt)
{
ll tp = Abs(tot + s[pos - 1].sum);
if(tp < Min) Min = tp, ans = s[pos - 1].num + tcnt;
else if(tp == Min) ans = min(ans, s[pos - 1].num + tcnt);
}
}
write(Min), space, write(ans), enter;
}
return 0;
}