Codeforces960G Bandit Blues 【斯特林数】【FFT】

时间:2023-03-08 15:35:12

题目大意:

  求满足比之前的任何数小的有A个,比之后的任何数小的有B个的长度为n的排列个数。

题目分析:

  首先写出递推式,设s(n,k)表示长度为n的排列,比之前的数小的数有k个。

  我们假设新加入的数为1,那么s(n,k)=s(n-1,k-1)+(n-1)*s(n,k)。

  这个式子是第一类斯特林数的递推式。

  用h(n,a,b)表示满足题目给出条件的排列个数。

  得出h(n,a,b)=Σs(k,a-1)*s(n-k-1,b-1)*C(n-1,k)。直观的理解就是将原排列从最高点分成两部分,两部分分别组合然后乘起来。

  这样我们发现h(n,a,b)=s(n-1,a+b-2)*C(a+b-2,a-1)。这实际上就是给出一个a+b-2的排列,然后选出其中需要的点放到右边,我们不用考虑多余的点,因为它们的排列已经被计算。

  由于无符号第一类斯特林数对应着升幂的系数,构造x(x+1)(x+2)...(x+n-1),它的x^k的系数等于s(n,k)的值,由于最高项系数为1,所以分治FFT。

代码:

  

 #include<bits/stdc++.h>
#pragma GCC optimize(2)
using namespace std; const int mod = ;
const int gg = ; int n,a,b; vector<int> res[]; int up[]; int ord[]; int fast_pow(int now,int pw){
if(pw == ) return ;
if(pw == ) return now;
int z = fast_pow(now,pw/);
z = (1ll*z*z)%mod;
if(pw & ){z= (1ll*z*now)%mod;}
return z;
} void fft(int now,int len,int f){
for(int i=;i<len;i++) if(i<ord[i]) swap(res[now][i],res[now][ord[i]]);
for(int i=;i<len;i<<=){
int wn = fast_pow(gg,(mod-)/(i<<));
if(f == -) wn = fast_pow(wn,mod-);
for(int j=;j<len;j+=(i<<)){
for(int k=,w=;k<i;k++,w = (1ll*w*wn)%mod){
int x = res[now][j+k],y = (1ll*w*res[now][j+k+i])%mod;
res[now][j+k] = (x+y)%mod;
res[now][j+k+i] = (x-y+mod)%mod;
}
}
}
if(f == -){
int iv = fast_pow(len,mod-);
for(int i=;i<len;i++) res[now][i] = (1ll*res[now][i]*iv)%mod;
}
} void multi(int p1,int p2){
int n1 = res[p1].size()-,n2 = res[p2].size()-;
int len = ,om = ;
while(len <= (n1+n2+))len<<=,om++;
for(int i=n1+;i<len;i++) res[p1].push_back();
for(int i=n2+;i<len;i++) res[p2].push_back();
for(int i=;i<len;i++) ord[i] = (ord[i>>]>>)+((i&)<<om-);
fft(p1,len,);fft(p2,len,);
for(int i=;i<len;i++){
res[p1][i] = (1ll*res[p1][i]*res[p2][i])%mod;
if(res[p1][i] < ) res[p1][i]+=mod;
}
fft(p1,len,-);
res[p2].clear();
} void divide(int l,int r,int now){
if(l == r) {up[now] = l;return;}
int mid = (l+r)/;
divide(l,mid,now<<);
divide(mid+,r,now<<|);
multi(up[now<<],up[now<<|]);
up[now] = up[now<<];
} void work(){
if(a == || b == ){puts("");return;}
if(n == ){if(a+b==)puts(""); else puts(""); return;}
int c = ;
if(a<b) swap(a,b);
if(a- > a+b-) c = ;
for(int i=;i<=a-;i++){
c = (1ll*c*(a+b--i))%mod;
c = (1ll*c*fast_pow(i,mod-))%mod;
}
for(int i=;i<n;i++) res[i].push_back(i-),res[i].push_back();
divide(,n-,);
c = (1ll*c*res[up[]][a+b-])%mod;
printf("%d",c);
} int main(){
scanf("%d%d%d",&n,&a,&b);
work();
return ;
}