Fast Fourier Transform

时间:2022-01-23 12:35:43

写在前面的..

感觉自己是应该学点新东西了.. 所以就挖个大坑,去学FFT了..


FFT是个啥?

坑已补上..

推荐去看黑书《算法导论》,讲的很详细


例题选讲


1.UOJ #34. 多项式乘法

这是FFT最裸的题目了 FFT就是拿来求这个东西的

没啥好讲的,把板子贴一下吧..

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cmath>
using namespace std;
const int Maxn = ;
struct cp {
double r, i;
cp ( double r=, double i= ) : r(r), i(i) {}
}A[Maxn], B[Maxn]; int an, bn, n;
cp operator + ( cp a, cp b ){ return cp ( a.r+b.r, a.i+b.i ); }
cp operator - ( cp a, cp b ){ return cp ( a.r-b.r, a.i-b.i ); }
cp operator * ( cp a, cp b ){ return cp ( a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r ); }
double pi = acos(-);
void fft ( cp *a, int op ){
int i, j, k;
j = ;
for ( i = ; i < n; i ++ ){
if ( i < j ) swap ( a[i], a[j] );
k = n>>;
while ( j & k ){ j -= k; k >>= ; }
j += k;
}
for ( i = ; i <= n; i <<= ){
cp wn = cp ( cos (2.0*pi/i), op * sin (2.0*pi/i) );
for ( j = ; j < n; j += i ){
cp w = cp ( , );
for ( k = j; k < j+i/; k ++ ){
cp x = a[k], y = w*a[k+i/];
a[k] = x+y; a[k+i/] = x-y;
w = w*wn;
}
}
}
if ( op == - ) for ( i = ; i < n; i ++ ) a[i].r /= n;
}
int main (){
int i, j, k;
scanf ( "%d%d", &an, &bn );
n = ; while ( n < an+bn+ ) n <<= ;
for ( i = ; i <= an; i ++ ) scanf ( "%lf", &A[i].r );
for ( i = ; i <= bn; i ++ ) scanf ( "%lf", &B[i].r );
fft ( A, ); fft ( B, );
for ( i = ; i < n; i ++ ) A[i] = A[i]*B[i];
fft ( A, - );
for ( i = ; i <= an+bn; i ++ ) printf ( "%d ", (int)(A[i].r+0.5) );
printf ( "\n" );
return ;
}

2.bzoj 3527: [Zjoi2014]力

这题在bzoj上没有题面,百度一下就行了,给个链接吧

化简一下公式:$$E_i=\sum\limits_{j<i}\dfrac{q_j}{(i-j)^2}-\sum\limits_{j>i}\dfrac{q_j}{(i-j)^2}$$

然后你设$g_i=\dfrac{1}{i^2}$,只考虑前面一部分的公式就可化成:$$E_i=\sum\limits_{j<i}q_j\times g_{i-j}$$

看是不是很眼熟,这就是很典型的FFT的模型

然后后面的就把整个数组反过来搞就是同样的道理..

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cmath>
using namespace std;
const int Maxn = ;
struct cp {
double r, i;
cp ( double r=0.0, double i=0.0 ) : r(r), i(i) {}
}g[Maxn*], h[Maxn*]; int n, nn;
cp operator + ( cp a, cp b ){ return cp ( a.r+b.r, a.i+b.i ); }
cp operator - ( cp a, cp b ){ return cp ( a.r-b.r, a.i-b.i ); }
cp operator * ( cp a, cp b ){ return cp ( a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r ); }
double pi = acos (-);
void fft ( cp *a, int op ){
int i, j, k;
for ( i = j = ; i < n; i ++ ){
if ( i < j ) swap ( a[i], a[j] );
k = n >> ;
while ( j & k ) j -= k, k >>= ;
j += k;
}
for ( i = ; i <= n; i <<= ){
cp wn = cp ( cos (2.0*pi/i), op * sin (2.0*pi/i) );
for ( j = ; j < n; j += i ){
cp w = cp ( , );
for ( k = j; k < j+i/; k ++ ){
cp x = a[k], y = w*a[k+i/];
a[k] = x+y; a[k+i/] = x-y;
w = w*wn;
}
}
}
if ( op == - ) for ( i = ; i < n; i ++ ) a[i].r /= n;
}
double q[Maxn], ans[Maxn];
int main (){
int i, j, k;
scanf ( "%d", &nn );
for ( i = ; i <= nn; i ++ ) scanf ( "%lf", &q[i] );
n = ; while ( n < *nn+ ) n <<= ;
for ( i = ; i <= nn; i ++ ){
g[i].r = q[i];
h[i].r = (double)/(double(i)*i);
}
fft ( g, ); fft ( h, );
for ( i = ; i < n; i ++ ) g[i] = g[i]*h[i];
fft ( g, - );
for ( i = ; i <= nn; i ++ ) ans[i] += g[i].r; for ( i = ; i <= n; i ++ ) g[i].r = g[i].i = h[i].r = h[i].i = ;
for ( i = ; i <= nn; i ++ ){
g[i].r = q[nn-i+];
h[i].r = (double)/(double(i)*i);
}
fft ( g, ); fft ( h, );
for ( i = ; i < n; i ++ ) g[i] = g[i]*h[i];
fft ( g, - );
for ( i = ; i <= nn; i ++ ) ans[i] -= g[nn-i+].r;
for ( i = ; i <= nn; i ++ ) printf ( "%lf\n", ans[i] );
return ;
}

3.bzoj 3160: 万径人踪灭

题意很长,但有用的不过那两句划线的..

仔细想想,如果$s_i==s_j$,那么就可以对$\dfrac{i+j}{2}$贡献答案了对吧

再想想,如果我已经知道了这$2n-1$(包括了中间的插缝)位置有$x$对数贡献了答案,在不考虑连续一段、取空集的情况下是可以有$2^x$种答案的

那么就在上面的基础上减去连续一段和空集的情况

连续一段的可以用manacher判断出,空集直接减去$1$就好了

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cmath>
using namespace std;
const int Maxn = ;
const int Mod = 1e9+;
struct cp {
double r, i;
cp ( double r=0.0, double i=0.0 ) : r(r), i(i) {}
}A[Maxn*];
cp operator + ( cp a, cp b ){ return cp ( a.r+b.r, a.i+b.i ); }
cp operator - ( cp a, cp b ){ return cp ( a.r-b.r, a.i-b.i ); }
cp operator * ( cp a, cp b ){ return cp ( a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r ); }
int n, nn;
char s[Maxn*]; int ans[Maxn*];
double pi = acos (-);
int _min ( int x, int y ){ return x < y ? x : y; }
void fft ( cp *a, int op ){
int i, j, k;
for ( i = j = ; i < n; i ++ ){
if ( i < j ) swap ( a[i], a[j] );
k = n >> ;
while ( j & k ) j -= k, k >>= ;
j += k;
}
for ( i = ; i <= n; i <<= ){
cp wn = cp ( cos (2.0*pi/i), op*sin(2.0*pi/i) );
for ( j = ; j < n; j += i ){
cp w = cp ( , );
for ( k = j; k < j+i/; k ++ ){
cp x = a[k], y = w*a[k+i/];
a[k] = x+y; a[k+i/] = x-y;
w = w*wn;
}
}
}
if ( op == - ) for ( i = ; i < n; i ++ ) a[i].r /= n;
}
int d[Maxn];
int rad[Maxn*];
int main (){
int i, j, k;
scanf ( "%s", s );
nn = strlen (s);
d[] = ;
for ( i = ; i <= nn; i ++ ) d[i] = ( d[i-] * ) % Mod;
n = ; while ( n < *nn+ ) n <<= ;
for ( i = ; i < nn; i ++ ){
if ( s[i] == 'a' ) A[i].r = ;
}
fft ( A, );
for ( i = ; i < n; i ++ ) A[i] = A[i]*A[i];
fft ( A, - );
for ( i = ; i <= *(nn-); i ++ ) ans[i] += ((int)(A[i].r+0.5)+)/; for ( i = ; i < n; i ++ ) A[i].r = A[i].i = ;
for ( i = ; i < nn; i ++ ){
if ( s[i] == 'b' ) A[i].r = ;
}
fft ( A, );
for ( i = ; i < n; i ++ ) A[i] = A[i]*A[i];
fft ( A, - );
for ( i = ; i <= *(nn-); i ++ ) ans[i] += ((int)(A[i].r+0.5)+)/;
int ret = ;
for ( i = nn-; i >= ; i -- ){
s[i*+] = s[i];
s[i*+] = '#';
}
s[] = '$'; s[] = '#'; s[*nn+] = '?';
rad[] = ;
k = ; int mx = ;
for ( i = ; i <= *nn; i ++ ){
if ( mx > i ){ rad[i] = _min ( rad[*k-i], mx-i ); }
else rad[i] = ;
while ( s[i-rad[i]] == s[i+rad[i]] && i-rad[i] > ) rad[i] ++;
if ( i+rad[i]- > mx ){
mx = i+rad[i]-;
k = i;
}
}
for ( i = ; i <= *(nn-); i ++ ){
ret = (ret+d[ans[i]]-rad[i+]/-)%Mod;
}
printf ( "%d\n", ret );
return ;
}

4.bzoj 3509: [CodeChef] COUNTARI

来自la1la1la的题解:

1)先看一个比较sb的做法:枚举$j$,算出前面$i$的数的总数$sum[a[i]]$,然后找后面$k$,统计$sum[2*a[j]-a[k]]$

时间复杂度为$O(n^2)$

2)再看一个更sb的做法:枚举$j$,搞搞前面的和后面的,FFT计算,累加和为$2*a[j]$的方案

时间复杂度为$O(n^2logn)$

好,那么这题的正解就是把这两个sb做法合在一起

用一种分块的方法,假设我们把每一块分成每块大小为$S$的

枚举块内的$j$,用第一种方法暴力搞出块内的$i$、$k$的所有情况,时间复杂度为$O(nS)$

对于块外的使用第二种方法,即FFT来做,再枚举块内$j$累计答案,时间复杂度为$O(\frac{n}{S}mlogm)$,其中$m$是最大的数

总时间是$O(nS+\frac{n}{S}mlogm)$,这个时候锅就该甩给你们的高中老师了——均值不等式:$$a+b\geq 2\sqrt{ab}$$

那么这个时间就是大于等于$2n\sqrt{mlogm}$,什么时候能取到等于呢,继续甩锅

当$S=\sqrt{mlogm}$就是了,你把$m=30000$带进去估算一下(记住要带点常数进去..)

大概就是$S=1000$就差不多了 反正我就是这么打的..

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#define LL long long
using namespace std;
const int Maxn = ;
const int Maxm = ;
struct cp {
double r, i;
cp ( double r=0.0, double i=0.0 ) : r(r), i(i) {}
}A[Maxm*], B[Maxm*]; int An;
cp operator + ( cp a, cp b ){ return cp ( a.r+b.r, a.i+b.i ); }
cp operator - ( cp a, cp b ){ return cp ( a.r-b.r, a.i-b.i ); }
cp operator * ( cp a, cp b ){ return cp ( a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r ); }
int a[Maxn], n, S;
int _min ( int x, int y ){ return x < y ? x : y; }
int _max ( int x, int y ){ return x > y ? x : y; }
int ss[Maxm*], ls[Maxm*], rs[Maxm*];
double pi = acos(-);
void fft ( cp *a, int op ){
int i, j, k;
for ( i = j = ; i < An; i ++ ){
if ( i < j ) swap ( a[i], a[j] );
k = An >> ;
while ( j & k ) j -= k, k >>= ;
j += k;
}
for ( i = ; i <= An; i <<= ){
cp wn = cp ( cos (2.0*pi/i), op * sin (2.0*pi/i) );
for ( j = ; j < An; j += i ){
cp w = cp ( , );
for ( k = j; k < j+i/; k ++ ){
cp x = a[k], y = w*a[k+i/];
a[k] = x+y; a[k+i/] = x-y;
w = w*wn;
}
}
}
if ( op == - ) for ( i = ; i < An; i ++ ) a[i].r /= An;
}
int main (){
int i, j, k;
scanf ( "%d", &n );
S = ;
int mx = ;
for ( i = ; i <= n; i ++ ){
scanf ( "%d", &a[i] );
rs[a[i]] ++;
mx = _max ( mx, a[i] );
}
An = ; while ( An < *mx+ ) An <<= ;
LL ans = ;
for ( i = ; i <= n; i += S ){
int p = _min ( i+S-, n );
for ( j = i; j <= p; j ++ ) rs[a[j]] --;
for ( j = i; j <= p; j ++ ){
for ( k = j+; k <= p; k ++ ){
if ( *a[j]-a[k] > ) ans += (LL)ss[*a[j]-a[k]]+ls[*a[j]-a[k]];
}
ss[a[j]] ++;
for ( k = i; k <= j-; k ++ ){
if ( *a[j]-a[k] > ) ans += (LL)rs[*a[j]-a[k]];
}
}
for ( j = ; j < An; j ++ ){
A[j].r = ls[j]; A[j].i = ;
B[j].r = rs[j]; B[j].i = ;
}
fft ( A, ); fft ( B, );
for ( j = ; j < An; j ++ ) A[j] = A[j]*B[j];
fft ( A, - );
for ( j = i; j <= p; j ++ ){
ans += (LL)(A[*a[j]].r+0.5);
}
for ( j = i; j <= p; j ++ ) ss[a[j]] --, ls[a[j]] ++;
}
printf ( "%lld\n", ans );
return ;
}

5.hdu 5730 Shell Necklace

先把原题公式化简一下就是:$$dp_i=\sum\limits_{j=1}^{i-1}a_{i-j}\times dp_j$$

那么的话就是cdq分治搞一下再套一个FFT计算更新答案了..

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cmath>
using namespace std;
const int Maxn = ;
const int Mod = ;
struct cp {
double r, i;
cp ( double r=0.0, double i=0.0 ) : r(r), i(i) {}
}A[Maxn*], B[Maxn*]; int An;
cp operator + ( cp a, cp b ){ return cp ( a.r+b.r, a.i+b.i ); }
cp operator - ( cp a, cp b ){ return cp ( a.r-b.r, a.i-b.i ); }
cp operator * ( cp a, cp b ){ return cp ( a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r ); }
int a[Maxn], n, f[Maxn];
double pi = acos (-);
void fft ( cp *a, int op ){
int i, j, k;
for ( i = j = ; i < An; i ++ ){
if ( i < j ) swap ( a[i], a[j] );
k = An >> ;
while ( j & k ) j -= k, k >>= ;
j += k;
}
for ( i = ; i <= An; i <<= ){
cp wn = cp ( cos (*pi/i), op * sin (*pi/i) );
for ( j = ; j < An; j += i ){
cp w = cp ( , );
for ( k = j; k < j+i/; k ++ ){
cp x = a[k], y = w*a[k+i/];
a[k] = x+y; a[k+i/] = x-y;
w = w*wn;
}
}
}
if ( op == - ) for ( i = ; i < An; i ++ ) a[i].r /= An;
}
void cdq ( int l, int r ){
if ( l == r ) return;
int mid = ( l + r ) >> ;
cdq ( l, mid );
int i, j, k;
int len = r-l;
An = ; while ( An < *len+ ) An <<= ;
for ( i = l; i <= mid; i ++ ){
A[i-l].r = f[i]; A[i-l].i = ;
B[i-l].r = a[i-l]; B[i-l].i = ;
}
for ( i = mid+; i <= r; i ++ ){
A[i-l].r = ; A[i-l].i = ;
B[i-l].r = a[i-l]; B[i-l].i = ;
}
for ( i = len+; i < An; i ++ ) A[i].r = A[i].i = B[i].r = B[i].i = ;
fft ( A, ); fft ( B, );
for ( i = ; i < An; i ++ ) A[i] = A[i]*B[i];
fft ( A, - );
for ( i = mid+; i <= r; i ++ ){
int ret = (int)(A[i-l].r+0.5);
ret %= Mod;
f[i] = ( f[i] + ret ) % Mod;
}
cdq ( mid+, r );
}
int main (){
int i, j, k;
while ( scanf ( "%d", &n ) != EOF ){
if ( n == ) break;
for ( i = ; i <= n; i ++ ){
scanf ( "%d", &a[i] ); a[i] %= Mod;
f[i] = ;
}
f[] = ;
cdq ( , n );
printf ( "%d\n", f[n] );
}
return ;
}

6.hdu 5307 He is Flying

题意:有$n$段路,每段路长度为$s_i$,你从节点$i$到节点$j$,可以获得一个开心值$j−i+1$,然后问你,主人公走过了所有总长度为$s$的段,问你有多少开心值。

%%%la1la1la

累计一个前缀和$sum$,那么题意就变成走$sum_j-sum_i$的路程获得$j-i$的开心值

那么就把$j$和$-i$分开来算,卷积一下就好了..

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#define LL long long
#define LD long double
using namespace std;
const LL Maxn = ;
const LL Maxs = ;
struct cp {
LD r, i;
cp ( LD r=0.0, LD i=0.0 ) : r(r), i(i) {}
}A[Maxs*], B[Maxs*]; LL An;
cp operator + ( cp a, cp b ){ return cp ( a.r+b.r, a.i+b.i ); }
cp operator - ( cp a, cp b ){ return cp ( a.r-b.r, a.i-b.i ); }
cp operator * ( cp a, cp b ){ return cp ( a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r ); }
LL sum[Maxn], n, ans[Maxs];
const LD pi = acos (-);
void fft ( cp *a, LL op ){
LL i, j, k;
for ( i = j = ; i < An; i ++ ){
if ( i < j ) swap ( a[i], a[j] );
k = An >> ;
while ( j & k ) j -= k, k >>= ;
j += k;
}
for ( i = ; i <= An; i <<= ){
for ( j = ; j < An; j += i ){
for ( k = j; k < j+i/; k ++ ){
cp x = a[k], y = cp ( cos (*pi*(k-j)/i), op * sin (*pi*(k-j)/i) ) * a[k+i/];
a[k] = x+y; a[k+i/] = x-y;
}
}
}
if ( op == - ) for ( i = ; i < An; i ++ ) a[i].r /= An;
}
LL ss[Maxn];
int main (){
LL i, j, k, T;
for ( i = ; i <= ; i ++ ) ss[i] = ss[i-]+i*(i+)/;
scanf ( "%I64d", &T );
while ( T -- ){
scanf ( "%I64d", &n );
LL lj = , ret = ;
for ( i = ; i <= n; i ++ ){
scanf ( "%I64d", &sum[i] );
if ( sum[i] == ) lj ++;
else {
ret += ss[lj];
lj = ;
}
sum[i] += sum[i-];
}
ret += ss[lj];
printf ( "%I64d\n", ret );
for ( i = ; i <= sum[n]; i ++ ) ans[i] = ;
An = ; while ( An < sum[n]* ) An <<= ;
for ( i = ; i < An; i ++ ) A[i].r = A[i].i = B[i].r = B[i].i = ;
for ( i = ; i <= n; i ++ ){
A[sum[n]-sum[i]].r += i;
B[sum[i]].r += ;
}
fft ( A, ); fft ( B, );
for ( i = ; i < An; i ++ ) A[i] = A[i]*B[i];
fft ( A, - );
for ( i = ; i <= sum[n]; i ++ ) ans[i] = -(LL)(A[sum[n]+i].r+0.5);
for ( i = ; i < An; i ++ ) A[i].r = A[i].i = B[i].r = B[i].i = ;
for ( i = ; i <= n; i ++ ){
A[sum[n]-sum[i]].r += ;
B[sum[i]].r += i;
}
fft ( A, ); fft ( B, );
for ( i = ; i < An; i ++ ) A[i] = A[i]*B[i];
fft ( A, - );
for ( i = ; i <= sum[n]; i ++ ) ans[i] += (LL)(A[sum[n]+i].r+0.5);
for ( i = ; i <= sum[n]; i ++ ) printf ( "%I64d\n", ans[i] );
}
return ;
}

7.hdu 4609/bzoj 3513 3-idiots

题意:给你$n$条边,问你任取三条边能组成三角形的概率

按权值卷积后,把边排序

对于任意一条边你都能知道其余两条边加起来大于它的方案数

考虑第$i$条边是最长边,那么就要减去一条选大的一条选小的、两条都选大的方案数,这些都可以用组合公式算得..

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#define LL long long
using namespace std;
const LL Maxn = ;
struct cp {
double r, i;
cp ( double r=0.0, double i=0.0 ) : r(r), i(i) {}
}A[Maxn*]; LL An;
cp operator + ( cp a, cp b ){ return cp ( a.r+b.r, a.i+b.i ); }
cp operator - ( cp a, cp b ){ return cp ( a.r-b.r, a.i-b.i ); }
cp operator * ( cp a, cp b ){ return cp ( a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r ); }
LL a[Maxn], n;
LL ans[Maxn*];
const double pi = acos (-);
LL _max ( LL x, LL y ){ return x > y ? x : y; }
void fft ( cp *a, LL op ){
LL i, j, k;
for ( i = j = ; i < An; i ++ ){
if ( i < j ) swap ( a[i], a[j] );
k = An >> ;
while ( j & k ) j -= k, k >>= ;
j += k;
}
for ( i = ; i <= An; i <<= ){
cp wn = cp ( cos (*pi/i), op * sin (*pi/i) );
for ( j = ; j < An; j += i ){
cp w = cp ( , );
for ( k = j; k < j+i/; k ++ ){
cp x = a[k], y = w*a[k+i/];
a[k] = x+y; a[k+i/] = x-y;
w = w*wn;
}
}
}
if ( op == - ) for ( i = ; i < An; i ++ ) A[i].r /= An;
}
int main (){
LL i, j, k, T;
scanf ( "%I64d", &T );
while ( T -- ){
scanf ( "%I64d", &n );
LL mx = ;
for ( i = ; i <= n; i ++ ){
scanf ( "%I64d", &a[i] );
mx = _max ( mx, a[i] );
}
An = ; while ( An < *mx+ ) An <<= ;
for ( i = ; i < An; i ++ ) A[i].r = A[i].i = ;
for ( i = ; i <= n; i ++ ){
A[a[i]].r += ;
}
fft ( A, );
for ( i = ; i < An; i ++ ) A[i] = A[i]*A[i];
fft ( A, - );
for ( i = ; i <= mx*; i ++ ) ans[i] = (LL)(A[i].r+0.5);
for ( i = ; i <= n; i ++ ) ans[*a[i]] --;
for ( i = ; i <= mx*; i ++ ) ans[i] /= ;
for ( i = mx*-; i >= ; i -- ) ans[i] += ans[i+];
sort ( a+, a+n+ );
LL ret = ;
for ( i = ; i <= n; i ++ ){
ret += (LL)ans[a[i]+]-n+;
ret -= (LL)(n-i)*(i-);
ret -= (LL)(n-i)*(n-i-)/;
}
LL zs = n*(n-)*(n-)/;
printf ( "%.7lf\n", (double)ret/zs );
}
return ;
}

8.hdu 5751 Eades

题意:

Peter有一个序列$a_1, a_2, ..., a_n$. 定义$g(l,r)$表示子序列$a_{l},a_{l+1},...,a_{r}$的最大值, $f(l,r)=\displaystyle\sum_{i=l}^{r}[a_i = g(l,r)]$. 注意$[\text{condition}] = 1$当且仅当$\text{condition}$是true, 否则$[\text{condition}] = 0$.

对于每个整数$k \in \{1, 2, ..., n\}$, Peter想要知道有多少整数对$l$和$r$ $(l \le r)$满足$f(l,r)=k$.

官方题解:

枚举每个数$x$, 考虑这个数成为最大值的对每个$z(\cdot)$的贡献. 对于每个数$x$, 你都可以求出若干个极大区间$[l_x,r_x]$, 表示这个$x$在区间内是最大值, 不妨假设这个$x$在$[l_x,r_x]$出现了$m$次, 每个位置分别为$p_1,p_2,...,p_m$, 那么就可以转化成一个长度为$m+1$的序列$c_0,c_2,...,c_m$. 其中$c_0=p_1-l+1$ $c_{m}=r-p_m+1$ $c_i=p_{i+1}-p_{i}, 1 \leq i < m$

于是对$z_k$的贡献就是$z_k=\sum_{i=0}^{m}c_i \cdot c_{i+k}$

这个其实就是$c_0,c_1,...,c_m$和$c_m,c_{m-1},...,c_0$的卷积, 做一次fft就好了.

时间复杂度分析:因为每一个数只有可能进入一次做fft,所以总时间均摊$O(nlogn)$

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <cmath>
#define LL long long
using namespace std;
const LL Maxn = ;
struct cp {
double r, i;
cp ( double r=0.0, double i=0.0 ) : r(r), i(i) {}
}A[Maxn*], B[Maxn*]; LL An;
cp operator + ( cp a, cp b ){ return cp ( a.r+b.r, a.i+b.i ); }
cp operator - ( cp a, cp b ){ return cp ( a.r-b.r, a.i-b.i ); }
cp operator * ( cp a, cp b ){ return cp ( a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r ); }
LL a[Maxn], maxx[Maxn*], n;
vector <LL> vec[Maxn];
LL _max ( LL x, LL y ){ return x > y ? x : y; }
void bulid_tree ( LL now, LL L, LL R ){
if ( L < R ){
LL mid = ( L + R ) >> , lc = now*, rc = now*+;
bulid_tree ( lc, L, mid );
bulid_tree ( rc, mid+, R );
maxx[now] = _max ( maxx[lc], maxx[rc] );
}
else maxx[now] = a[L];
}
LL query ( LL now, LL L, LL R, LL l, LL r ){
if ( L == l && R == r ) return maxx[now];
LL mid = ( L + R ) >> , lc = now*, rc = now*+;
if ( r <= mid ) return query ( lc, L, mid, l, r );
else if ( l > mid ) return query ( rc, mid+, R, l, r );
else return _max ( query ( lc, L, mid, l, mid ), query ( rc, mid+, R, mid+, r ) );
}
LL ans[Maxn];
const double pi = acos (-);
void fft ( cp *a, LL op ){
LL i, j, k;
for ( i = j = ; i < An; i ++ ){
if ( i < j ) swap ( a[i], a[j] );
k = An >> ;
while ( j & k ) j -= k, k >>= ;
j += k;
}
for ( i = ; i <= An; i <<= ){
cp wn = cp ( cos (*pi/i), op * sin (*pi/i) );
for ( j = ; j < An; j += i ){
cp w = cp ( , );
for ( k = j; k < j+i/; k ++ ){
cp x = a[k], y = w*a[k+i/];
a[k] = x+y; a[k+i/] = x-y;
w = w*wn;
}
}
}
if ( op == - ) for ( i = ; i < An; i ++ ) a[i].r /= An;
}
void solve ( LL l, LL r ){
if ( l > r ) return;
LL i, j, k;
LL Max = query ( , , n, l, r );
LL x = lower_bound ( vec[Max].begin (), vec[Max].end (), l ) - vec[Max].begin ();
LL y = upper_bound ( vec[Max].begin (), vec[Max].end (), r ) - vec[Max].begin ();
An = ; while ( An < *(y-x)+ ) An <<= ;
for ( i = ; i < An; i ++ ) A[i].r = A[i].i = B[i].r = B[i].i = ;
A[].r = vec[Max][x]-l+;
for ( i = x; i < y-; i ++ ){
A[i-x+].r = vec[Max][i+]-vec[Max][i];
}
A[y-x].r = r-vec[Max][y-]+;
for ( i = ; i <= y-x; i ++ ) B[i].r = A[y-x-i].r;
fft ( A, ); fft ( B, );
for ( i = ; i < An; i ++ ) A[i] = A[i]*B[i];
fft ( A, - );
for ( i = ; i <= y-x; i ++ ) ans[i] += (LL)(A[i+y-x].r+0.5);
solve ( l, vec[Max][x]- );
for ( i = x; i < y-; i ++ ) solve ( vec[Max][i]+, vec[Max][i+]- );
solve ( vec[Max][y-]+, r );
}
int main (){
LL i, j, k, T;
scanf ( "%I64d", &T );
while ( T -- ){
scanf ( "%I64d", &n );
for ( i = ; i <= n; i ++ ) vec[i].clear ();
for ( i = ; i <= n; i ++ ){ scanf ( "%I64d", &a[i] ); vec[a[i]].push_back (i); }
bulid_tree ( , , n );
for ( i = ; i <= n; i ++ ) ans[i] = ;
solve ( , n );
LL ret = ;
for ( i = ; i <= n; i ++ ){
ret += ans[i]^i;
}
printf ( "%I64d\n", ret );
}
return ;
}

9.bzoj 4332: JSOI2012 分零食

简化题意:把$m$分到$n$个位置,允许后缀为$0$,每个位置的值为$f(i)=Oi^2+Si+U$,其中$i$是该位置分到的数,问你所有情况的权值积的和

找到三种做法:

YJQ大爷

WerKeyTom (讲道理不是很理解..)

某大神 (比较厉害的做法..)

我就是用YJQ大爷的方法的.. 他的题解讲的不是很详细就来补充一下吧

设$f_{i,j}$表示把$j$分到$i$个位置而且全部填满没有后缀$0$的答案

$g_{i,j}$表示把$j$分到$i$个位置而且一定至少有一个空没填的答案

由于$n\leq 10^8$,所以肯定是要用倍增的方法

然后就是可以这么推:$$f_{i,j}=\sum\limits_{k=1}^{j-1}f_{\frac{i}{2},j-k}\cdot f_{\frac{i}{2},k}$$

$$g_{i,j}=\sum\limits_{k=1}^{j-1}g_{\frac{i}{2},j-k}\cdot f_{\frac{i}{2},k}$$

大概就是这样..

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cmath>
using namespace std;
const int Maxn = ;
int m, P, n, O, S, U;
struct cp {
double r, i;
cp ( double r=0.0, double i=0.0 ) : r(r), i(i) {}
}A[Maxn*], B[Maxn*];
cp operator + ( cp a, cp b ){ return cp ( a.r+b.r, a.i+b.i ); }
cp operator - ( cp a, cp b ){ return cp ( a.r-b.r, a.i-b.i ); }
cp operator * ( cp a, cp b ){ return cp ( a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r ); }
int f[Maxn*], g[Maxn*], ansf[Maxn*], ansg[Maxn*], An;
int rev[Maxn*];
const double pi = acos (-);
void fft ( cp *a, int op ){
int i, j, k;
for ( i = ; i < An; i ++ ) if ( i < rev[i] ) swap ( a[i], a[rev[i]] );
for ( i = ; i <= An; i <<= ){
cp wn = cp ( cos (*pi/i), op * sin (*pi/i) );
for ( j = ; j < An; j += i ){
cp w = cp ( , );
for ( k = j; k < j+i/; k ++ ){
cp x = a[k], y = w*a[k+i/];
a[k] = x+y; a[k+i/] = x-y;
w = w*wn;
}
}
}
if ( op == - ) for ( i = ; i < An; i ++ ) a[i].r /= An;
}
int main (){
int i, j, k;
scanf ( "%d%d%d%d%d%d", &m, &P, &n, &O, &S, &U );
An = ; while ( An < *m+ ) An <<= ;
j = ;
for ( i = ; i < An; i ++ ){
rev[i] = j;
k = An >> ;
while ( j & k ) j -= k, k >>= ;
j += k;
}
g[] = ;
for ( i = ; i <= m; i ++ ){
f[i] = (((O*i)%P*i)%P+(S*i)%P+U)%P;
}
ansf[] = ; ansg[] = ;
while (n){
if ( n & ){
for ( i = ; i <= m; i ++ ) A[i].r = ansf[i], B[i].r = g[i], A[i].i = B[i].i = ;
for ( i = m+; i < An; i ++ ) A[i].r = B[i].r = A[i].i = B[i].i = ;
fft ( A, ); fft ( B, );
for ( i = ; i < An; i ++ ) A[i] = A[i]*B[i];
fft ( A, - );
for ( i = ; i <= m; i ++ ) ansg[i] = ( ansg[i] + ((int)(A[i].r+0.5))%P ) % P; for ( i = ; i <= m; i ++ ) A[i].r = ansf[i], B[i].r = f[i], A[i].i = B[i].i = ;
for ( i = m+; i < An; i ++ ) A[i].r = B[i].r = A[i].i = B[i].i = ;
fft ( A, ); fft ( B, );
for ( i = ; i < An; i ++ ) A[i] = A[i]*B[i];
fft ( A, - );
for ( i = ; i <= m; i ++ ) ansf[i] = ((int)(A[i].r+0.5))%P;
}
for ( i = ; i <= m; i ++ ) A[i].r = f[i], B[i].r = g[i], A[i].i = B[i].i = ;
for ( i = m+; i < An; i ++ ) A[i].r = B[i].r = A[i].i = B[i].i = ;
fft ( A, ); fft ( B, );
for ( i = ; i < An; i ++ ) A[i] = A[i]*B[i];
fft ( A, - );
for ( i = ; i <= m; i ++ ) g[i] = ( g[i] + ((int)(A[i].r+0.5))%P ) % P; for ( i = ; i <= m; i ++ ) A[i].r = f[i], A[i].i = ;
for ( i = m+; i < An; i ++ ) A[i].r = A[i].i = ;
fft ( A, );
for ( i = ; i < An; i ++ ) A[i] = A[i]*A[i];
fft ( A, - );
for ( i = ; i <= m; i ++ ) f[i] = ((int)(A[i].r+0.5))%P;
n >>= ;
}
printf ( "%d\n", (ansf[m]+ansg[m])%P );
return ;
}

10.bzoj 3456: 城市规划

这是一道比较厉害的NTT题目.. 我不是很会解释,建议去看看WerKeyTom的blog

这道题的时间也比较坑,我卡了很久的常数才过..

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <ctime>
#include <iostream>
#define LL long long
using namespace std;
const LL Maxn = ;
const LL Mod = ;
const LL G = ;
LL A[Maxn*], B[Maxn*], An;
LL pow ( LL x, LL k ){
LL ret = ;
while (k){
if ( k & ) ret = (ret*x)%Mod;
x = (x*x)%Mod;
k >>= ;
}
return ret;
}
LL v[Maxn], f[Maxn], jc[Maxn], fjc[Maxn], n, inv[Maxn], cf[Maxn*], fcf[Maxn*], p[Maxn];
void dft ( LL *a, LL op ){
LL i, j, k;
for ( i = j = ; i < An; i ++ ){
if ( i < j ) swap ( a[i], a[j] );
k = An >> ;
while ( j & k ) j -= k, k >>= ;
j += k;
}
for ( i = ; i <= An; i <<= ){
LL wn = op > ? cf[i] : fcf[i];
LL w = ;
for ( k = ; k < i/; k ++ ){
for ( j = ; j < An; j += i ){
LL x = a[k+j], y = (w*a[k+j+i/])%Mod;
a[k+j] = (x+y)%Mod; a[k+j+i/] = (x-y+Mod)%Mod;
}
w = (w*wn)%Mod;
}
}
if ( op == - ) for ( i = ; i < An; i ++ ) a[i] = (a[i]*inv[An])%Mod;
}
void cdq ( LL l, LL r ){
if ( l == r ){ f[l] = (v[l]-(jc[l-]*f[l])%Mod+Mod)%Mod; return; }
LL mid = ( l + r ) >> , i;
cdq ( l, mid );
LL len = r-l+;
An = ; while ( An < *len+ ) An <<= ;
for ( i = ; i < An; i ++ ) A[i] = B[i] = ;
for ( i = ; i < len; i ++ ) B[i] = p[i];
for ( i = ; i < mid-l+; i ++ ) A[i] = (f[i+l]*fjc[i+l-])%Mod;
dft ( A, ); dft ( B, );
for ( i = ; i < An; i ++ ) A[i] = (A[i]*B[i])%Mod;
dft ( A, - );
for ( i = mid+; i <= r; i ++ ) f[i] = (f[i]+A[i-l])%Mod;
cdq ( mid+, r );
}
int c[Maxn];
int main (){
LL i, j, k;
scanf ( "%lld", &n );
jc[] = ; fjc[] = ; v[] = ;
c[] = ; for ( i = ; i <= n; i ++ ) c[i] = (c[i-]*)%Mod;
for ( i = ; i <= n; i ++ ){
v[i] = ( v[i-] * c[i-] ) % Mod;
jc[i] = (jc[i-]*i)%Mod;
}
fjc[n] = pow ( jc[n], Mod- );
for ( i = n-; i >= ; i -- ){
fjc[i] = (fjc[i+]*(i+))%Mod;
p[i] = (v[i]*fjc[i])%Mod;
}
An = ; while ( An < *n+ ) An <<= ;
for ( i = ; i <= An; i <<= ) cf[i] = pow ( G, (Mod-)/i ), fcf[i] = pow ( G, Mod--(Mod-)/i ), inv[i] = pow ( i, Mod- );
cdq ( , n );
printf ( "%lld\n", f[n] );
return ;
}

11.bzoj 4555: [Tjoi2016&Heoi2016]求和

把公式化一下基本就出来了,这里给出几个公式:

斯特林数化简公式:$$S(n,m)=\dfrac{1}{m!}\sum\limits_{k=0}^{m}(-1)^k\cdot C_m^k\cdot (m-k)^n$$

等比数列求和公式:$$Sum=\dfrac{a(q^k-1)}{q-1}$$

那么最后化成的公式为:$$f(n)=\sum\limits_{j=0}^n2^j\cdot j!\sum\limits_{k=0}^jg(k)\cdot h(j-k)$$

$$g(i)=\dfrac{(-1)^i}{i!},\ h(i)=\dfrac{\sum_{k=0}^{n}i^k}{i!}$$

 #include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define LL long long
using namespace std;
const LL Maxn = ;
const LL Mod = ;
const LL G = ;
LL jc[Maxn], inv[Maxn], invn;
LL pow ( LL x, LL k ){
LL ret = ;
while (k){
if ( k & ) ret = (ret*x)%Mod;
x = (x*x)%Mod;
k >>= ;
}
return ret;
}
LL A[Maxn*], B[Maxn*], An;
LL n;
LL dft ( LL *a, LL op ){
LL i, j, k;
for ( i = j = ; i < An; i ++ ){
if ( i < j ) swap ( a[i], a[j] );
k = An >> ;
while ( j & k ) j -= k, k >>= ;
j += k;
}
for ( i = ; i <= An; i <<= ){
LL wn = pow ( G, op > ? (Mod-)/i : Mod--(Mod-)/i );
for ( j = ; j < An; j += i ){
LL w = ;
for ( k = j; k < j+i/; k ++ ){
LL x = a[k], y = (w*a[k+i/])%Mod;
a[k] = (x+y)%Mod; a[k+i/] = (x-y+Mod)%Mod;
w = (w*wn)%Mod;
}
}
}
if ( op == - ) for ( i = ; i < An; i ++ ) a[i] = (a[i]*invn)%Mod;
}
int main (){
LL i, j, k;
scanf ( "%lld", &n );
An = ; while ( An < *n+ ) An <<= ;
jc[] = jc[] = ;
for ( i = ; i <= n; i ++ ) jc[i] = (jc[i-]*i)%Mod;
inv[n] = pow ( jc[n], Mod- );
for ( i = n-; i >= ; i -- ) inv[i] = (inv[i+]*(i+))%Mod;
A[] = ;
for ( i = ; i <= n; i ++ ) A[i] = i % == ? Mod-inv[i] : inv[i];
B[] = ; B[] = n+;
for ( i = ; i <= n; i ++ ) B[i] = ( ( ( pow ( i, n+ ) - ) * pow ( i-, Mod- ) ) % Mod * inv[i] ) % Mod;
invn = pow ( An, Mod- );
dft ( A, ); dft ( B, );
for ( i = ; i < An; i ++ ) A[i] = (A[i]*B[i])%Mod;
dft ( A, - );
LL s = , ans = ;
for ( i = ; i <= n; i ++ ){
ans = ( ans + ((s*jc[i])%Mod*A[i])%Mod ) % Mod;
s = (s*)%Mod;
}
printf ( "%lld\n", ans );
return ;
}

参考资料

picks的blog