[TJOI2019]唱、跳、rap和篮球——NTT+生成函数+容斥

时间:2021-04-19 08:31:32

题目链接:

[TJOI2019]唱、跳、rap和篮球

直接求不好求,我们考虑容斥,求出至少有$i$个聚集区间的方案数$ans_{i}$,那么最终答案就是$\sum\limits_{i=0}^{n}(-1)^i\ ans_{i}$

那么现在只需要考虑至少有$i$个聚集区间的方案数,我们枚举这$i$个区间的起始点位置,一共有$C_{n-3i}^{i}$种方案(可以看作是刚开始先将每个区间后三个位置去掉,从剩下$n-3i$个位置中选出$i$个区间起点,然后再在每个起点后面加上三个位置)。

那么剩下的$n-4i$个位置就是随便放这四种学生,假设第$j$种学生放了$a_{j}$个、一共有$num_{j}$个,那么方案数就是$\frac{(n-4i)!}{\prod_{j=1}^{4}a_{j}!}$。

由此可以构造出这四种学生的生成函数,以第一种学生为例:$\sum\limits_{j=0}^{num_{1}-i}\frac{x^j}{j!}$

将四个生成函数分别用$NTT$乘在一起然后取$x^{n-4i}$前的系数乘上$(n-4i)!$即可得到$n-4i$个位置随便放的方案数。

#include<set>
#include<map>
#include<cmath>
#include<stack>
#include<queue>
#include<bitset>
#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int mod=998244353;
int f[3000];
int g[3000];
int inv[2000];
int fac[2000];
int mask;
int n,a,b,c,d;
int ans;
int mn,mx;
int quick(int x,int y)
{
int res=1;
while(y)
{
if(y&1)
{
res=1ll*res*x%mod;
}
x=1ll*x*x%mod;
y>>=1;
}
return res;
}
void NTT(int *a,int len,int opt)
{
for(int i=0,k=0;i<len;i++)
{
if(i>k)
{
swap(a[i],a[k]);
}
for(int j=len>>1;(k^=j)<j;j>>=1);
}
for(int i=2;i<=len;i<<=1)
{
int t=i>>1;
int x=quick(3,(mod-1)/i);
if(opt==-1)
{
x=quick(x,mod-2);
}
for(int j=0;j<len;j+=i)
{
int w=1;
for(int k=j;k<j+t;k++)
{
int tmp=1ll*a[k+t]*w%mod;
a[k+t]=(a[k]-tmp+mod)%mod;
a[k]=(a[k]+tmp)%mod;
w=1ll*w*x%mod;
}
}
}
if(opt==-1)
{
int x=quick(len,mod-2);
for(int i=0;i<len;i++)
{
a[i]=1ll*a[i]*x%mod;
}
}
}
int C(int n,int m)
{
return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;
}
int solve(int x)
{
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
for(int i=0;i<=a-x;i++)
{
f[i]=inv[i];
}
for(int i=0;i<=b-x;i++)
{
g[i]=inv[i];
}
NTT(f,mask,1);
NTT(g,mask,1);
for(int i=0;i<mask;i++)
{
f[i]=1ll*f[i]*g[i]%mod;
}
memset(g,0,sizeof(g));
for(int i=0;i<=c-x;i++)
{
g[i]=inv[i];
}
NTT(g,mask,1);
for(int i=0;i<mask;i++)
{
f[i]=1ll*f[i]*g[i]%mod;
}
memset(g,0,sizeof(g));
for(int i=0;i<=d-x;i++)
{
g[i]=inv[i];
}
NTT(g,mask,1);
for(int i=0;i<mask;i++)
{
f[i]=1ll*f[i]*g[i]%mod;
}
NTT(f,mask,-1);
return 1ll*f[n-4*x]*fac[n-4*x]%mod*C(n-3*x,x)%mod;
}
int main()
{
inv[1]=inv[0]=fac[0]=1;
for(int i=1;i<=1000;i++)
{
fac[i]=1ll*fac[i-1]*i%mod;
}
for(int i=2;i<=1000;i++)
{
inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
}
for(int i=1;i<=1000;i++)
{
inv[i]=1ll*inv[i-1]*inv[i]%mod;
}
mask=1;
scanf("%d%d%d%d%d",&n,&a,&b,&c,&d);
mn=min(n/4,min(min(a,b),min(c,d)));
mx=max(max(a,b),max(c,d));
while(mask<=(mx<<2))
{
mask<<=1;
}
for(int i=0;i<=mn;i++)
{
if(i&1)
{
ans=(ans-solve(i)+mod)%mod;
}
else
{
ans=(ans+solve(i))%mod;
}
}
printf("%d",ans);
}