HDU 4609 3-idiots ——(FFT)

时间:2021-09-25 02:10:15

  这是我接触的第一个关于FFT的题目,留个模板。

  这题的题解见:http://www.cnblogs.com/kuangbin/archive/2013/07/24/3210565.html

  FFT的模板如下:

 #include<bits/stdc++.h>
using namespace std;
const double pi = atan(1.0)*;
struct Complex {
double x,y;
Complex(double _x=,double _y=)
:x(_x),y(_y) {}
Complex operator + (Complex &tt) { return Complex(x+tt.x,y+tt.y); }
Complex operator - (Complex &tt) { return Complex(x-tt.x,y-tt.y); }
Complex operator * (Complex &tt) { return Complex(x*tt.x-y*tt.y,x*tt.y+y*tt.x); }
};
Complex a[],b[];
void fft(Complex *a, int n, int rev) {
// n是(大于等于相乘的两个数组长度)2的幂次 ; 比如长度是5 ,那么 n = 8 2^2 < 5 2^3 > 5
// 从0开始表示长度,对a进行操作
// rev==1进行DFT,==-1进行IDFT
for (int i = ,j = ; i < n; ++ i) {
for (int k = n>>; k > (j^=k); k >>= );
if (i<j) std::swap(a[i],a[j]);
}
for (int m = ; m <= n; m <<= ) {
Complex wm(cos(*pi*rev/m),sin(*pi*rev/m));
for (int i = ; i < n; i += m) {
Complex w(1.0,0.0);
for (int j = i; j < i+m/; ++ j) {
Complex t = w*a[j+m/];
a[j+m/] = a[j] - t;
a[j] = a[j] + t;
w = w * wm;
}
}
}
if (rev==-) {
for (int i = ; i < n; ++ i) a[i].x /= n,a[i].y /= n;
}
}
int main(){
a[] = Complex(,); // a[0]: x的0次项。
a[] = Complex(,);
a[] = Complex(,);
a[] = Complex(,); b[] = Complex(,);
b[] = Complex(,);
b[] = Complex(,);
b[] = Complex(,);
fft(a,,);
fft(b,,);
for(int i = ; i < ; i ++){
a[i] = a[i] * b[i];
}
fft(a,,-);
for(int i = ; i < ; i ++){
cout << i << " " << a[i].x << endl;;
}
/*
* fft:nlogn求两个多项式相乘,(原来要n^2)
*
* f1 = 0x^0 + 1x^1 + 2x^2 + 3x^3
* f2 = 3 + 2x^1 + x^2 + 0x^3;
*
* f1 = a + b + c + d;
* f2 = x + y + z + k;
* f3 = __
* dp[1] dp[2] dp[3] dp[4] dp[5];
* c[0] c[1] c[2] c[3];
*/
/*
0 0 * x^0
1 3 * x^1
2 8 * x^2
3 14
4 8 -----> 0x^0 + 3x^1 + 8x^2 ...... 3x^5
5 3
6 0 * x^6
7 0 * x^7
* */
return ;
}

FFT模板

  关于这个模板有几点需要注意的:  

  1.在系数转化成整数时,会有精度误差,需要加eps。

  2.假设a和b之前的长度都是n,卷积以后的大小应该是2*n-1,再考虑到fft中第二个参数n必须是大于等于卷积长度的2的幂次,因此最后的数组长度必须是n的4倍,也就是说这里的a数组大小应该开4倍才行。

  3.注意在对a和b进行fft(a, LIM, 1)之前需要对n+1~LIM之间的复数也初始化为Complex(0, 0)以免多组测试中之前的操作对当前的操作产生影响。

  最后,本题的AC代码如下:

 #include<bits/stdc++.h>
using namespace std;
const double pi = atan(1.0)*;
const int N = 1e5 + ;
typedef long long ll; struct Complex {
double x,y;
Complex(double _x=,double _y=)
:x(_x),y(_y) {}
Complex operator + (Complex &tt) { return Complex(x+tt.x,y+tt.y); }
Complex operator - (Complex &tt) { return Complex(x-tt.x,y-tt.y); }
Complex operator * (Complex &tt) { return Complex(x*tt.x-y*tt.y,x*tt.y+y*tt.x); }
};
Complex a[N*],b[N];
void fft(Complex *a, int n, int rev) {
// n是(大于等于相乘的两个数组长度)2的幂次 ; 比如长度是5 ,那么 n = 8 2^2 < 5 2^3 > 5
// 从0开始表示长度,对a进行操作
// rev==1进行DFT,==-1进行IDFT
for (int i = ,j = ; i < n; ++ i) {
for (int k = n>>; k > (j^=k); k >>= );
if (i<j) std::swap(a[i],a[j]);
}
for (int m = ; m <= n; m <<= ) {
Complex wm(cos(*pi*rev/m),sin(*pi*rev/m));
for (int i = ; i < n; i += m) {
Complex w(1.0,0.0);
for (int j = i; j < i+m/; ++ j) {
Complex t = w*a[j+m/];
a[j+m/] = a[j] - t;
a[j] = a[j] + t;
w = w * wm;
}
}
}
if (rev==-) {
for (int i = ; i < n; ++ i) a[i].x /= n,a[i].y /= n;
}
} int A[N];
ll num[N*], sum[N*]; int main(){
/*a[0] = Complex(0,0); // a[0]: x的0次项。
a[1] = Complex(1,0);
a[2] = Complex(2,0);
a[3] = Complex(3,0); b[0] = Complex(3,0);
b[1] = Complex(2,0);
b[2] = Complex(1,0);
b[3] = Complex(0,0);
fft(a,8,1);
fft(b,8,1);
for(int i = 0 ; i < 8 ; i ++){
a[i] = a[i] * b[i];
}
fft(a,8,-1);
for(int i = 0 ; i < 8 ; i ++){
cout << i << " " << a[i].x << endl;;
}*/
//a[0] = Complex(0, 0); b[0] = Complex(0, 0);
int T; scanf("%d",&T);
while(T--)
{
int n; scanf("%d",&n);
memset(num,,sizeof num);
for(int i=;i<=n;i++)
{
scanf("%d",A+i);
num[A[i]]++;
}
sort(A+, A++n);
int len = A[n];
int LIM = ;
while()
{
if(LIM >= len*+) break;
else LIM <<= ;
}
for(int i=;i<=len;i++)
{
a[i] = Complex(num[i], );
}
for(int i=len+;i<LIM;i++)
{
a[i] = Complex(, );
}
fft(a,LIM,);
for(int i=;i<LIM;i++) a[i] = a[i] * a[i];
fft(a,LIM,-);
// finish fft
len = len * ;
for(int i=;i<=len;i++) num[i] = (ll)(a[i].x + 0.5);
for(int i=;i<=n;i++) num[A[i]*]--; // 减去两个相同的组合
for(int i=;i<=len;i++) num[i] /= ; // 选择的是无序的
for(int i=;i<=len;i++) sum[i] = sum[i-] + num[i];
ll ans = ;
for(int i=;i<=n;i++)
{
int now = A[i];
ans += sum[len] - sum[now]; // 对于每个数,加起来比它大的都是可行的
ans -= 1LL * (i-) * (n-i); // 减去一个大的一个小的组合的情况
ans -= 1LL * (n-); // 减去自己和任意一个组合的情况
ans -= 1LL * (n-i) * (n-i-) / ; // 减去比它大的两个组合的情况
// 剩下的就是两个小的加起来比A[i]大的情况
}
ll all = 1LL * n*(n-)*(n-) / ;
printf("%.7f\n",1.0*ans/all);
}
return ;
}

AC代码