Loj #6503. 「雅礼集训 2018 Day4」Magic
题目描述
前进!前进!不择手段地前进!——托马斯 · 维德
魔法纪元元年。
1453 年 5 月 3 日 16 时,高维碎片接触地球。
1453 年 5 月 28 日 21 时,碎片完全离开地球。
1453 年,君士坦丁堡被围城,迪奥娜拉接触到四维泡沫空间,成为魔法师,最终因高维碎片消失失去魔力而身死。
为了改写这段历史,你不惜耗费你珍藏已久的魔术卡来回到魔法纪元元年。
在使用这些魔术卡之前,你却对它们的排列起了兴趣...
桌面上摆放着 \(m\) 种魔术卡,共 \(n\) 张,第 \(i\) 种魔术卡数量为 \(a_i\),魔术卡顺次摆放,形成一个长度为 \(n\) 的魔术序列,在魔术序列中,若两张相邻魔术卡的种类相同,则它们被称为一个魔术对。
两个魔术序列本质不同,当且仅当存在至少一个位置,使得两个魔术序列这个位置上的魔术卡的种类不同,求本质不同的恰好包含 \(k\) 个魔术对的魔术序列的数量,答案对 \(998244353\) 取模。
输入格式
第一行三个整数 \(m, n, k\)。
第二行 \(m\) 个正整数,第 \(i\) 个正整数表示 \(a_i\)。
输出格式
一行一个整数表示答案。
数据范围与提示
对于 \(100 \%\) 的数据满足 \(1 \leq m \leq 20000, 0 \leq k \leq n \leq 100000, \sum_{i = 1}^{m} a_i = n\)。
首先假设同种颜色的卡片是有标号的,因为这样要好做得多。最后算出来的答案还要乘上\(\frac{1}{\prod_{i=1}^m {a_i!}}\)。
然后就可以考虑容斥了。设\(g_i\)表示强制有\(i\)个魔术对的方案数,答案为\(\sum_{i=k}^n(-1)^{i-k}\binom{k}{i}g_i\)。
对于第\(i\)种卡片,我们要先求出强制要求有\(x\)对魔术对的方案数。这个问题等价于将\(a_i\)张卡片分成\(a_i-x\)个排列,答案为\(\frac{a_i!\cdot \binom{a_i-1}{a_i-x-1}}{(a_i-x)!}\)。所以对于第\(i\)中卡片,将其分为\(x\)个排列的生成函数为
\]
将所有的\(f\)乘起来就得到了\(g\),具体实现可以用分治\(NTT\)。
\]
因为这些排列之间的相对位置也需要确定,所以\(p\)个排列的方案数要乘上\(p!\)。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 100005
#define M 20005
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
const ll mod=998244353;
ll ksm(ll t,ll x) {
ll ans=1;
for(;x;x>>=1,t=t*t%mod)
if(x&1) ans=ans*t%mod;
return ans;
}
int m,n,k;
int a[M];
ll W[20][N<<2];
void pre(int s) {
for(int i=1;i<=s;i++) {
int len=1<<i;
ll t=ksm(3,(mod-1)/len);
W[i][0]=1;
for(int j=1;j<=len;j++) W[i][j]=W[i][j-1]*t%mod;
}
}
void NTT(ll *a,int d,int flag) {
static int rev[N<<2];
int n=1<<d;
for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int s=1;s<=d;s++) {
int len=1<<s,mid=len>>1;
for(int i=0;i<n;i+=len) {
for(int j=0;j<mid;j++) {
ll t=flag==1?W[s][j]:W[s][len-j];
ll u=a[i+j],v=a[i+j+mid]*t%mod;
a[i+j]=(u+v)%mod;
a[i+j+mid]=(u-v+mod)%mod;
}
}
}
if(flag==-1) {
ll inv=ksm(n,mod-2);
for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
}
}
ll fac[N],ifac[N];
ll C(int n,int m) {return n<m?0:fac[n]*ifac[m]%mod*ifac[n-m]%mod;}
vector<int>st[M],g;
int size[M];
int Find(int L,int R) {
int l=L,r=R,mid;
while(l<r) {
mid=l+r+1>>1;
if(size[mid]-size[L-1]<=size[R]-size[mid-1]) l=mid;
else r=mid-1;
}
return l;
}
vector<int>solve(int l,int r) {
static ll A[N<<2],B[N<<2];
if(l==r) return st[l];
int mid=Find(l,r),tot=size[r]-size[l-1];
vector<int>L=solve(l,mid),R=solve(mid+1,r),ans;
ans.clear();
int d=ceil(log2(tot+1));
for(int i=0;i<1<<d;i++) A[i]=B[i]=0;
for(int i=0;i<L.size();i++) A[i]=L[i];
for(int i=0;i<R.size();i++) B[i]=R[i];
NTT(A,d,1),NTT(B,d,1);
for(int i=0;i<1<<d;i++) A[i]=A[i]*B[i]%mod;
NTT(A,d,-1);
for(int i=0;i<=tot;i++) ans.push_back(A[i]);
return ans;
}
bool cmp(int x,int y) {return x<y;}
ll ans;
int main() {
pre(18);
m=Get(),n=Get(),k=Get();
fac[0]=1;
for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod;
ifac[n]=ksm(fac[n],mod-2);
for(int i=n-1;i>=0;i--) ifac[i]=ifac[i+1]*(i+1)%mod;
for(int i=1;i<=m;i++) a[i]=Get();
sort(a+1,a+1+m,cmp);
for(int i=1;i<=m;i++) {
st[i].push_back(0);
for(int j=1;j<=a[i];j++) {
st[i].push_back(C(a[i],j)*fac[a[i]-1]%mod*ifac[j-1]%mod);
}
}
for(int i=1;i<=m;i++) size[i]=size[i-1]+a[i];
g=solve(1,m);
for(int i=0;i<=n;i++) g[i]=g[i]*fac[i]%mod;
for(int i=0;i<=n-k;i++) {
if((k-(n-i))&1) ans-=C(n-i,k)*g[i]%mod;
else ans+=C(n-i,k)*g[i]%mod;
}
ans=(ans%mod+mod)%mod;
for(int i=1;i<=m;i++) ans=ans*ifac[a[i]]%mod;
cout<<ans;
return 0;
}