[UOJ310] 黎明前的巧克力

时间:2022-11-17 21:28:02

Sol

某比赛搬了这题。

首先选择两个不交非空子集且异或和为0的方案数,等价于选择一个异或和为0的集合,并把它分成两部分的方案数。
这显然可以DP来算,设 \(f[i][j]\) 表示前\(i\)个数异或和为\(j\)的方案数,那么转移就是 \(f[i][j]=f[i-1][j]+2\cdot f[i-1][j\;\text{xor}\;a[i]]\)

如果设 \(b_i[0]=1,b_i[a[i]]=2,b_i[j]=0\),那么这个转移就是求\(f\)\(b_i\;\text{xor}\)卷积的过程,可以用FWT优化,但是复杂度似乎更爆炸了。

如果我们可以把每个\(b\) FWT之后的结果都求出来并乘在一起,最后在对应位置乘到\(f\)上,再把\(f\) IFWT回去不就好了嘛!

如果把\(b_i\)数组FWT之后的结果打印出来,会发现所有位置不是\(3\)就是\(-1\),大概是因为这个\(2\)对每一项的贡献要么是\(2\)要么是\(-2\)

我们可以先把\(b_i\)数组整个加起来,对它做一次FWT。

因为FWT的和等于和的FWT。对于FWT之后的第\(i\)\(s\),设这位有\(x\)个数为\(-1\),那么就有\(n-x\)个数为 \(3\),且\(3(n-x)-x=s\),解得 \(x=\large \frac{3n-s}4\) 。那么FWT之后这一项的值就是 \((-1)^x3^{n-x}\)

然后乘到\(f\)上再IFWT回去就行了。

(uoj被卡了我不知道这代码能过否
(mp数组开小了,已经改过来了

Code

#include<set>
#include<map>
#include<queue>
#include<cmath>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
typedef double db;
typedef long long ll;
const int N=1048578;
const int maxn=1048576;
const int mod=998244353;
const int inv2=(mod+1)/2;

int n,f[N],b[N],po[N];

void Mul(int &x,int y){x=1ll*x*y%mod;}
int mul(int x,int y){return 1ll*x*y%mod;}
void Dec(int &x,int y){x=x-y<0?x+mod-y:x-y;}
int dec(int x,int y){return x-y<0?x+mod-y:x-y;}
void Inc(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
int inc(int x,int y){return x+y>=mod?x+y-mod:x+y;}

int ksm(int a,int b=mod-2,int ans=1){
    while(b){
        if(b&1) ans=mul(ans,a);
        a=mul(a,a); b>>=1; 
    } return ans;
}

void fwt(int *f,int opt){
    for(int mid=1;mid<maxn;mid<<=1){
        for(int R=mid<<1,j=0;j<maxn;j+=R){
            for(int k=0;k<mid;k++){
                int x=f[j+k],y=f[j+k+mid];
                f[j+k]=inc(x,y),f[j+k+mid]=dec(x,y);
                if(opt<1) Mul(f[j+k],inv2),Mul(f[j+k+mid],inv2);
            }
        }
    }
}

signed main(){
    scanf("%d",&n); f[0]=1; fwt(f,1);
    po[0]=1; for(int i=1;i<=n;i++) po[i]=mul(po[i-1],3);
    for(int x,i=1;i<=n;i++)
        scanf("%d",&x),b[0]++,b[x]+=2;
    fwt(b,1); int ni=ksm(4);
    for(int i=0;i<maxn;i++){
        int x=mul(dec(n*3,b[i]),ni);
        Mul(f[i],x&1?mod-po[n-x]:po[n-x]);
    } fwt(f,-1); printf("%d\n",dec(f[0],1));
}