HDU4609 FFT+组合计数

时间:2023-12-04 15:26:38

HDU4609 FFT+组合计数

传送门:http://acm.hdu.edu.cn/showproblem.php?pid=4609

题意:

找出n根木棍中取出三根木棍可以组成三角形的概率

题解:

我们统计每种长度的棍子的个数

我们对于长度就有一个多项式

\[f=num[0]*i_0+num[1]*i_1+num[2]*i_2.....num[len]*i_len
\]

我们考虑两根棍子可以组成所有长度的方案数

所以我们对num数组求一次FFT

两根棍子组成长度的上界是\(len_{max}*2\)

可能存在棍子重复组合的情况,这个时候我们需要去重

去掉两种重复的情况

1.自己和自己组合 即去除a[i]+a[i]的情况

2.A和B组合 B又和A组合的情况 这种时候每个组合/2即可

然后通过组合数计数即可,tips:在过程中可能爆int

代码:

#include <set>
#include <map>
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
typedef pair<int, int> pii;
typedef unsigned long long uLL;
#define ls rt<<1
#define rs rt<<1|1
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define bug printf("*********\n")
#define FIN freopen("input.txt","r",stdin);
#define FON freopen("output.txt","w+",stdout);
#define IO ios::sync_with_stdio(false),cin.tie(0)
#define debug1(x) cout<<"["<<#x<<" "<<(x)<<"]\n"
#define debug2(x,y) cout<<"["<<#x<<" "<<(x)<<" "<<#y<<" "<<(y)<<"]\n"
#define debug3(x,y,z) cout<<"["<<#x<<" "<<(x)<<" "<<#y<<" "<<(y)<<" "<<#z<<" "<<z<<"]\n"
const int maxn = 3e5 + 5;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + 7;
const double Pi = acos(-1);
LL quick_pow(LL x, LL y) {
LL ans = 1;
while(y) {
if(y & 1) {
ans = ans * x % mod;
} x = x * x % mod;
y >>= 1;
} return ans;
}
struct complex {
double x, y;
complex(double xx = 0, double yy = 0) {
x = xx, y = yy;
}
} x[maxn]; int a[maxn];
complex operator + (complex a, complex b) {
return complex(a.x + b.x, a.y + b.y);
}
complex operator - (complex a, complex b) {
return complex(a.x - b.x, a.y - b.y);
}
complex operator * (complex a, complex b) {
return complex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
} int n, m;
int l, r[maxn];
int limit = 1;
void fft(complex *A, int type) {
for(int i = 0; i < limit; i++) {
if(i < r[i]) swap(A[i], A[r[i]]);
}
for(int mid = 1; mid < limit; mid <<= 1) {
complex Wn(cos(Pi / mid), type * sin(Pi / mid));
for(int R = mid << 1, j = 0; j < limit; j += R) {
complex w(1, 0);
for(int k = 0; k < mid; k++, w = w * Wn) {
complex x = A[j + k], y = w * A[j + mid + k];
A[j + k] = x + y;
A[j + k + mid] = x - y;
}
}
}
}
LL num[maxn];//100000*100000会超int
LL sum[maxn];
int main() {
#ifndef ONLINE_JUDGE
FIN
#endif
int T;
scanf("%d", &T);
while(T--) {
int n;
memset(num, 0, sizeof(num));
scanf("%d", &n);
for(int i = 0; i < n; i++) {
scanf("%d", &a[i]);
num[a[i]]++;
}
sort(a, a + n);
int len1 = a[n - 1] + 1;
limit = 1;
l = 0;
while(limit < 2 * len1) limit <<= 1, l++;
for(int i = 0; i < len1; i++) {
x[i] = complex(num[i], 0);
}
for(int i = len1; i < limit ; i++) {
x[i] = complex(0, 0);
}
for(int i = 0; i < limit; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
}
fft(x, 1);
for(int i = 0; i < limit; i++) {
x[i] = x[i] * x[i];
}
fft(x, -1);
for(int i = 0; i < limit; i++) {
x[i].x /= limit;
}
for(int i = 0; i < limit; i++) {
num[i] = (LL)(x[i].x + 0.5);
// debug1(num[i]);
}
limit = 2 * a[n - 1];
//去重,去除 a_i,a_i这种情况
for(int i = 0; i < n; i++) {
num[a[i] + a[i]]--;
}
//去重,去除 (a_i,a_j),(a_j,a_i)这种情况
for(int i = 1; i <= limit; i++) {
num[i] /= 2;
}
sum[0] = 0;
for(int i = 1; i <= limit; i++)
sum[i] = sum[i - 1] + num[i];
LL cnt = 0;
for(int i = 0; i < n; i++) {
cnt += sum[limit] - sum[a[i]];
//减掉一个取大,一个取小的
cnt -= (long long)(n - 1 - i) * i;
//减掉一个取本身,另外一个取其它
cnt -= (n - 1);
//减掉大于它的取两个的组合
cnt -= (long long)(n - 1 - i) * (n - i - 2) / 2;
}
//总数
long long tot = (long long)n * (n - 1) * (n - 2) / 6;
printf("%.7f\n", (double)cnt / tot); }
return 0;
}