BZOJ3513[MUTC2013]idiots——FFT+生成函数

时间:2023-12-04 15:31:09

题目描述

给定n个长度分别为a_i的木棒,问随机选择3个木棒能够拼成三角形的概率。

输入

第一行T(T<=100),表示数据组数。
接下来若干行描述T组数据,每组数据第一行是n,接下来一行有n个数表示a_i。
3≤N≤10^5,1≤a_i≤10^5

输出

T行,每行一个整数,四舍五入保留7位小数。

样例输入

2
4
1 3 3 4
4
2 3 3 4

样例输出

0.5000000
1.0000000

提示

T<=20

N<=100000

首先开一个桶就可以得到长度分别为[1,100000]的木棒个数,只要将桶自己与自己卷积FFT一下就能得到两个木棒组成的任意长度的方案数(注意去重)。三个木棒不合法的情况当且仅当两个木棒之和小于等于第三个木棒,对桶求一个后缀和(或对方案数求一个前缀和)即可。

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<cstdio>
#include<bitset>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const double pi=acos(-1.0);
int n,T,x;
ll t[400010];
struct miku
{
double x,y;
miku(double X=0,double Y=0){x=X,y=Y;}
}f[400010];
miku operator + (miku a,miku b){return miku(a.x+b.x,a.y+b.y);}
miku operator - (miku a,miku b){return miku(a.x-b.x,a.y-b.y);}
miku operator * (miku a,miku b){return miku(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
int l,r[400010];
int a[100010];
int mask;
inline void DFT(miku *A)
{
for(int i=0;i<mask;i++)
{
if(i<r[i])
{
swap(A[i],A[r[i]]);
}
}
for(int mid=1;mid<mask;mid<<=1)
{
miku id(cos(pi/mid),sin(pi/mid));
for(int i=mid<<1,j=0;j<mask;j+=i)
{
miku w(1,0);
for(int k=0;k<mid;k++,w=w*id)
{
miku x=A[j+k],y=w*A[j+k+mid];
A[j+k]=x+y;
A[j+k+mid]=x-y;
}
}
}
}
inline void IDFT(miku *A)
{
for(int i=0;i<mask;i++)
{
if(i<r[i])
{
swap(A[i],A[r[i]]);
}
}
for(int mid=1;mid<mask;mid<<=1)
{
miku id(cos(pi/mid),-1.0*sin(pi/mid));
for(int i=mid<<1,j=0;j<mask;j+=i)
{
miku w(1,0);
for(int k=0;k<mid;k++,w=w*id)
{
miku x=A[j+k],y=w*A[j+k+mid];
A[j+k]=x+y;
A[j+k+mid]=x-y;
}
}
}
}
int main()
{
scanf("%d",&T);
mask=1;
l=0;
while(mask<=200000)
{
mask<<=1;
l++;
}
for(int i=0;i<mask;i++)
{
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
}
while(T--)
{
scanf("%d",&n);
memset(t,0,sizeof(t));
int mx=0;
for(int i=0;i<mask;i++)
{
f[i]=0;
}
for(int i=1;i<=n;i++)
{
scanf("%d",&x);
a[i]=x;
f[x].x++;
mx=max(mx,x);
}
DFT(f);
for(int i=0;i<mask;i++)
{
f[i]=f[i]*f[i];
}
IDFT(f);
for(int i=0;i<mask;i++)
{
f[i].x/=mask;
}
for(int i=1;i<=n;i++)
{
f[a[i]<<1].x--;
}
for(int i=1;i<=mx;i++)
{
t[i]=t[i-1]+(ll)(f[i].x/2+0.1);
}
ll ans=0;
for(int i=1;i<=n;i++)
{
ans+=t[a[i]];
}
printf("%.7f\n",1-(1.0*ans/(1.0*n*(n-1)/2*(n-2)/3)));
}
}