[HDOJ4609]3-idiots(FFT,计数)

时间:2023-05-31 17:03:37

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4609

题意:n个数,问取三个数可以构成三角形的组合数。

FFT预处理出两个数的组合情况,然后枚举第三个数,计数去重。

 #include <bits/stdc++.h>
using namespace std; const double PI = acos(-1.0);
//复数结构体
typedef struct Complex {
double r,i;
Complex(double _r = 0.0,double _i = 0.0) {
r = _r; i = _i;
}
Complex operator +(const Complex &b) {
return Complex(r+b.r,i+b.i);
}
Complex operator -(const Complex &b) {
return Complex(r-b.r,i-b.i);
}
Complex operator *(const Complex &b) {
return Complex(r*b.r-i*b.i,r*b.i+i*b.r);
}
}Complex;
/*
* 进行FFT和IFFT前的反转变换。
* 位置i和 (i二进制反转后位置)互换
* len必须是2的幂
*/
void change(Complex y[],int len) {
int i,j,k;
for(i = , j = len/;i < len-; i++) {
if(i < j)swap(y[i],y[j]);
//交换互为小标反转的元素,i<j保证交换一次
//i做正常的+1,j左反转类型的+1,始终保持i和j是反转的
k = len/;
while( j >= k) {
j -= k;
k /= ;
}
if(j < k) j += k;
}
}
/*
* 做FFT
* len必须为2^k形式,
* on==1时是DFT,on==-1时是IDFT
*/
void fft(Complex y[],int len,int on) {
change(y,len);
for(int h = ; h <= len; h <<= ) {
Complex wn(cos(-on**PI/h),sin(-on**PI/h));
for(int j = ;j < len;j+=h) {
Complex w(,);
for(int k = j;k < j+h/;k++) {
Complex u = y[k];
Complex t = w*y[k+h/];
y[k] = u+t;
y[k+h/] = u-t;
w = w*wn;
}
}
}
if(on == -) {
for(int i = ;i < len;i++) {
y[i].r /= len;
}
}
} typedef long long LL;
const int maxn = ;
Complex x1[maxn];
int a[maxn/];
LL num[maxn], s[maxn];
int n, len, q; int main() {
// freopen("in", "r", stdin);
int T;
scanf("%d", &T);
while(T--) {
scanf("%d",&n);
memset(a, , sizeof(a));
memset(s, , sizeof(s));
memset(num, , sizeof(num));
int maxx = ;
for(int i = ; i < n; i++) {
scanf("%I64d", &a[i]);
num[a[i]]++;
maxx = max(maxx, a[i]);
}
int len1 = maxx + ;
len = ;
while(len < len1 * ) len <<= ;
for(int i = ; i < len; i++) x1[i] = Complex(, );
for(int i = ; i < len1; i++) x1[i] = Complex(num[i], );
fft(x1, len, );
for(int i = ; i < len; i++) x1[i] = x1[i] * x1[i];
fft(x1, len, -);
for(int i = ; i < len; i++) num[i] = (LL)(x1[i].r + 0.5);
len = * maxx;
for(int i = ; i < n; i++) num[a[i]*]--;
for(int i = ; i <= len; i++) num[i] /= ;
for(int i = ; i <= len; i++) s[i] = s[i-] + num[i];
LL ret = ;
for(int i = ; i < n; i++) {
ret += s[len] - s[a[i]];
ret -= (LL)(n - i - ) * i;
ret -= (n - );
ret -= (LL)(n - i - ) * (n - i - ) / ;
}
LL sum = (LL)n * (n - ) * (n - ) / ;
// cout << (double)ret/(double)((n*(n-1)*(n-2))/6) << endl;
printf("%.7lf\n", (double)ret/sum);
}
return ;
}