![UOJ310. 【UNR #2】黎明前的巧克力 [FWT] UOJ310. 【UNR #2】黎明前的巧克力 [FWT]](https://image.shishitao.com:8440/aHR0cHM6Ly9ia3FzaW1nLmlrYWZhbi5jb20vdXBsb2FkL2NoYXRncHQtcy5wbmc%2FIQ%3D%3D.png?!?w=700&webp=1)
思路
显然可以转化一下,变成统计异或起来等于0的集合个数,这样一个集合的贡献是\(2^{|S|}\)。
考虑朴素的\(dp_{i,j}\)表示前\(i\)个数凑出了\(j\)的方案数,发现这其实就是一堆多项式用异或卷积搞起来。第\(i\)个多项式是\(1+2x^{a_i}\)。
对\(1+2x^{a}\)FWT一下,发现结果就只有-1
和3
。为什么?根据FWT的理论,\(a_i\)会对\(FWT(a)_j\)产生\(a_i\times (-1)^{\text{bitcnt}[i\&j]}\)的贡献。
我们就是要求出最后这一堆东西乘在一起是什么,也就是对于每一位求出这里有几个-1,有几个3。
这个怎么做?脑洞一下,把所有多项式加在一起FWT,设有\(x\)个-1,那么就有方程\(-x+3(n-x)=f_i\),就可以解出来了。
最后再FWT回去,就做完了。
(这个解方程咋想到的啊qwq)
代码
最后减1不取模你人就没了qwq
#include<bits/stdc++.h>
clock_t t=clock();
namespace my_std{
using namespace std;
#define pii pair<int,int>
#define fir first
#define sec second
#define MP make_pair
#define rep(i,x,y) for (int i=(x);i<=(y);i++)
#define drep(i,x,y) for (int i=(x);i>=(y);i--)
#define go(x) for (int i=head[x];i;i=edge[i].nxt)
#define templ template<typename T>
#define sz 1100000
#define mod 998244353ll
typedef long long ll;
typedef double db;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
templ inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
templ inline bool chkmax(T &x,T y){return x<y?x=y,1:0;}
templ inline bool chkmin(T &x,T y){return x>y?x=y,1:0;}
templ inline void read(T& t)
{
t=0;char f=0,ch=getchar();double d=0.1;
while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
t=(f?-t:t);
}
template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
char __sr[1<<21],__z[20];int __C=-1,__zz=0;
inline void Ot(){fwrite(__sr,1,__C+1,stdout),__C=-1;}
inline void print(register int x)
{
if(__C>1<<20)Ot();if(x<0)__sr[++__C]='-',x=-x;
while(__z[++__zz]=x%10+48,x/=10);
while(__sr[++__C]=__z[__zz],--__zz);__sr[++__C]='\n';
}
void file()
{
#ifdef NTFOrz
freopen("a.in","r",stdin);
#endif
}
inline void chktime()
{
#ifndef ONLINE_JUDGE
cout<<(clock()-t)/1000.0<<'\n';
#endif
}
#ifdef mod
ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;return ret;}
ll inv(ll x){return ksm(x,mod-2);}
#else
ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;return ret;}
#endif
// inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std;
int n;
int a[sz];
int f[sz];
ll g[sz];
void FWT(int *a,int n)
{
int N=1<<n,x,y;
rep(i,0,n-1)
for (int mid=1<<i,j=0;j<N;j+=mid<<1)
rep(k,0,mid-1)
x=a[j+k],y=a[j+k+mid],a[j+k]=x+y,a[j+k+mid]=x-y;
}
ll I=inv(2);
void iFWT(ll *a,int n)
{
int N=1<<n;ll x,y;
rep(i,0,n-1)
for (int mid=1<<i,j=0;j<N;j+=mid<<1)
rep(k,0,mid-1)
x=a[j+k],y=a[j+k+mid],a[j+k]=(x+y)*I%mod,a[j+k+mid]=(x-y+mod)*I%mod;
}
int main()
{
file();
read(n);
rep(i,1,n) read(a[i]),++f[0],f[a[i]]+=2;
FWT(f,20);
int x;
rep(i,0,(1<<20)-1) x=(3*n-f[i])/4,g[i]=ksm(mod-1,x)*ksm(3,n-x)%mod;
iFWT(g,20);
printf("%lld\n",(g[0]-1+mod)%mod);
return 0;
}
扩展
updated on 2020.2.1
这个解方程的思路其实不止在这题可以用到,还有luogu5577。
同样是要把一堆多项式乘在一起,但扩展到\(k\)阶循环卷积。
我们还是想要把所有东西加在一起之后求出每个\(\omega\)的系数,但是现在解方程就不那么容易。为了方便起见,下面FWT只对\(\sum x^{a_i}\)做。
假设现在在算\(f_i\)的组成。令\(y=\omega_k\),有\(f_i=\sum_{j=0}^{k-1} x_{i,j} y^j\),其中\(x_{i,j}\)是我们真实想要的东西。
但是,FWT出来的系数不一定是真的\(x_{i,j}\),因为有万恶的折半引理、求和引理、消去引理……
对于\(k\)是奇质数的情况,发现只有求和引理还有效。由于我们知道\(\sum x_{i,j}=n\),所以可以求出到底消去了多少个\(\sum_j y^j\),补上就完事了。
对于\(k\)的性质不那么好的情况呢?我们考虑大力解方程。令\(y=\omega^0,\omega^1,\cdots,\omega^{k-1}\),分别求出一大堆\(x\)。由于范德蒙德矩阵的存在,可以很轻松地解出真正的\(x_{i,j}\),就做完了。