子集卷积学习笔记

时间:2024-10-07 18:59:54

前言

子集卷积,指的是 c i = ∑ j ⊕ k = i a j × b k c_i=\sum\limits_{j\oplus k=i}a_j\times b_k ci=jk=iaj×bk,其中 ⊕ \oplus 指的是 ∪ , ∩ , x o r \cup,\cap,xor ,,xor
在这里,把 i , j , k i,j,k i,j,k二进制看作集合。

或卷积

现在我们先求 c i = ∑ j ∪ k = i a j × b k c_i=\sum\limits_{j\cup k = i}a_j\times b_k ci=jk=iaj×bk
朴素的求法是 O ( n 2 ) O(n^2) O(n2) 的。
我们回忆 F F T FFT FFT 是怎么加速卷积过程的。 F F T FFT FFT 是快速得到多项式的一个点值表示,其实是找到一种变换,满足 C i ′ = A i ′ × B i ′ C'_i=A'_i\times B'_i Ci=Ai×Bi,让它能够在 O ( n ) O(n) O(n) 的时间内得到 C C C 的点值表示。
在这里,我们也用这个思想。不妨令 A ′ A' A 表示 A A A 的变换。
注意到(我也不知道怎么想出来的),当 a i ′ = ∑ j ⊆ i a j a'_i=\sum\limits_{j\subseteq i}a_j ai=jiaj c i ′ = a i ′ × b i ′ c'_i=a'_i\times b'_i ci=ai×bi
a i ′ a'_i ai 其实就是子集和。我们稍微证明一下。

c i ′ = ∑ d ⊆ i c d = ∑ d ⊆ i ∑ j ∪ k = d a j × b k = ∑ j ∪ i , k ∪ i a j × b k = ( ∑ j ∪ i a j ) × ( ∑ k ∪ i b k ) = a i ′ × b i ′ c'_i=\sum\limits_{d\subseteq i}c_d=\sum\limits_{d\subseteq i}\sum\limits_{j\cup k=d}a_j\times b_k=\sum\limits_{j\cup i,k\cup i}a_j\times b_k=(\sum\limits_{j\cup i}a_j)\times(\sum\limits_{k\cup i}b_k)=a'_i\times b'_i ci=dicd=dijk=daj×bk=ji,kiaj×bk=(jiaj)×(kibk)=ai×bi

这样的话,我们的任务就变成了,快速求 a ′ , b ′ a',b' a,b,也就是快速求 a i ′ = ∑ j ⊆ i a j a'_i=\sum\limits_{j\subseteq i}a_j ai=jiaj

考虑长度为 2 n 2^n 2n 的序列 a 0 , a 1 , . . . a 2 n − 1 a_0,a_1,...a_{2^n-1} a0,a1,...a2n1
我们把它分为前一半和后一半,记为 A 0 A_0 A0 A 1 A_1 A1
那么 A ′ = m e r g e ( A 0 ′ , A 0 ′ + A 1 ′ ) A'=merge(A_0',A_0'+A_1') A=merge(A0,A0+A1),其中 m e r g e merge merge 指的是两个序列合并,加号指的是对应位置相加。
所以我们可以递归实现这个过程。但知道了原理后,我们就可以用自底向上的循环来实现这个过程。

代码长这样(和FFT很像,很好记)
void fwtor(int *a, int type){
	int j, R, k, mid;
	for(mid = 1; mid < n; mid <<= 1){
		for(R = mid << 1, j = 0; j < n; j += R){
			for(k = 0; k < mid; k++){
				a[j + k + mid] = (a[j + k + mid] + a[j + k] * type) % mod;
			}
		}
	}
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

得到了 C ′ C' C,又该怎么还原出 C C C 呢?把对应位置的贡献减掉就行了,其实就是把加变成减。
复杂度是 O ( n l o g n ) O(nlogn) O(nlogn) 的。

与卷积

c i = ∑ j ∩ k = i a j × b k c_i=\sum\limits_{j\cap k = i}a_j\times b_k ci=jk=iaj×bk
和或卷积类似,不过 A ′ = m e r g e ( A 0 ′ + A 1 ′ , A 1 ′ ) A'=merge(A_0'+A_1',A_1') A=merge(A0+A1,A1)。(因为 A ′ A' A 是补集和)
所以代码长这样:

void fwtand(int *a, int type){
	int j, R, k, mid;
	for(mid = 1; mid < n; mid <<= 1){
		for(R = mid << 1, j = 0; j < n; j += R){
			for(k = 0; k < mid; k++){
				a[j + k] = (a[j + k] + a[j + k + mid] * type) % mod;
			}
		}
	}
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

异或卷积

a ⊗ b a \otimes b ab 表示 ∣ a ∩ b ∣   m o d   2 |a\cap b| ~mod~2 ab mod 2,也就是 a , b a,b a,b 交集大小的奇偶性。
F W T ( A ) i = ∑ j ⊗ i = 0 a j − ∑ j ⊗ i = 1 a j FWT(A)_i=\sum\limits_{j\otimes i=0}a_j-\sum\limits_{j\otimes i=1}a_j FWT(A)i=ji=0ajji=1aj

不难推导(挺难推导) F W T ( C ) i = F W T ( A ) i × F W T ( B ) i FWT(C)_i=FWT(A)_i\times FWT(B)_i FWT(C)i=FWT(A)i×FWT(B)i

考虑怎么求 F W T ( A ) FWT(A) FWT(A) I F W T ( A ′ ) IFWT(A') IFWT(A)
仍然考虑分治,记前一半和后一半为 A 0 A_0 A0 A 1 A_1 A1
那么,考虑最终序列的 F W T FWT FWT
考虑前面一半,由于 A 0 A_0 A0 A 0 A_0 A0 A 1 A_1 A1 取交集最高位都是 0 0 0,这样交集的奇偶性不会发生变化,因此前一半是 F W T ( A 0 ) + F W T ( A 1 ) FWT(A_0)+FWT(A_1) FWT(A0)+FWT(A1)
考虑后面一半,由于 A 0 A_0 A0 A 1 A_1 A1 取交集最高位是 0 0 0,因此 A 0 A_0 A0 的贡献就是 F W T ( A 0 ) FWT(A_0) FWT(A0) A 1 A_1 A1 A 1 A_1 A1 取交集最高位是 1 1 1(奇偶性发生了变化),因此贡献是 − F W T ( A 1 ) -FWT(A_1) FWT(A1)
所以 F W T ( A ) = m e r g e ( F W T ( A 0 ) + F W T ( A 1 ) , F W T ( A 0 ) − F W T ( A 1 ) ) FWT(A)=merge(FWT(A_0)+FWT(A_1),FWT(A_0)-FWT(A_1)) FWT(A)=merge(FWT(A0)+FWT(A1),FWT(A0)FWT(A1))
不难反推出 I F W T ( A ′ ) IFWT(A') IFWT(A)(解个二元一次方程就行了QAQ),即 a = ( a 0 + a 1 2 , a 0 − a 1 2 ) a=(\frac{a_0+a_1}{2},\frac{a_0-a_1}{2}) a=(2a0+a1,2a0a1)
然后再从底到顶循环就行了。

代码长这样
void fwt(int *a){
	int j, R, k, mid, x, y;
	for(mid = 1; mid < n; mid <<= 1){
		for(R = mid << 1, j = 0; j < n; j += R){
			for(k = 0; k < mid; k++){
				x = a[j + k], y = a[j + k + mid];
				a[j + k] = (x + y) % mod;
				a[j + k + mid] = (x - y) % mod;
			}
		}
	}
}

void ifwt(int *a){
	int j, R, k, mid, x, y;
	for(mid = 1; mid < n; mid <<= 1){
		for(R = mid << 1, j = 0; j < n; j += R){
			for(k = 0; k < mid; k++){
				x = a[j + k], y = a[j + k + mid];
				a[j + k] = z * (x + y) * inv2 % mod;
				a[j + k + mid] = z * (x - y) * inv2 % mod;
			}
		}
	}
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

子集卷积

现在要求 c i = ∑ j ∪ k = i , j ∩ k = ∅ a j × b k c_i=\sum\limits_{j\cup k=i,j\cap k=\empty}a_j\times b_k ci=jk=i,jk=aj×bk

j ∪ k = i j\cup k=i jk=i,直接或卷积即可。
j ∩ k = i j\cap k = i jk=i,在满足第一个条件时,等价于 ∣ j ∣ + ∣ k ∣ = ∣ i ∣ |j|+|k|=|i| j+k=i
因此,我们在或卷积的时候多开一维,具体的,令 f i , s f_{i,s} fi,s 表示集合 s s s 的所有大小为 i i i 的子集和。
不难证明 c i , s = ∑ j = 0 i a j , s × b i − j , s c_{i,s}=\sum\limits_{j=0}^{i}a_{j,s}\times b_{i-j,s} ci,s=j=0iaj,s×bij,s
然后我们最终要求的就是 c ∣ i ∣ , i c_{|i|,i} ci,i
复杂度是 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)

代码如下

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
LL z = 1;
int read(){
	int x, f = 1;
	char ch;
	while(ch = getchar(), ch < '0' || ch > '9') if(ch == '-') f = -1;
	x = ch - '0';
	while(ch = getchar(), ch >= '0' && ch <= '9') x = x * 10 + ch - 48;
	return x * f;
}

const int N = (1 << 21), mod = 1e9 + 9;

int n, a[21][N], b[21][N], c[21][N], sz[N];

void fwt(int *a, int type){
	int j, k, mid, R;
	for(mid = 1; mid < n; mid <<= 1){
		for(R = mid << 1, j = 0; j < n; j += R){
			for(k = 0; k < mid; k++){
				a[j + k + mid] = (a[j + k + mid] + a[j + k] * type) % mod;
			}
		}
	}
}


int main(){
	int i, j, k, m;
	for(i = 1; i < (1 << 21); i++) sz[i] = sz[i - (i & -i)] + 1;
	m = read(); n = (1 << m);
	for(i = 0; i < n; i++) a[sz[i]][i] = read();
	for(i = 0; i < n; i++) b[sz[i]][i] = read();
	for(i = 0; i <= m; i++){
		fwt(a[i], 1); 
		fwt(b[i], 1);
	}
	for(i = 0; i <= m; i++){
		for(j = 0; j <= i; j++){
			for(k = 0; k < n; k++)
				c[i][k] = (c[i][k] + z * a[j][k] * b[i - j][k] % mod) % mod;
		}
	}
	
	for(i = 0; i <= m; i++) fwt(c[i], -1);
	for(i = 0; i < n; i++){
		printf("%d ", (c[sz[i]][i] + mod) % mod);
	}
	return 0;
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53