论文笔记CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX

时间:2022-12-21 17:54:43

Gumbel-Softmax分布

Gumbel-Softmax分布是一个定义在单纯形(simplex)上的连续分布。
Gumbel-Softmax分布可以近似categorical分布。

Gumbel-Max trick

z z z表示为服从 π = ( π 1 , … , π k ) \pi = (\pi_1,\ldots,\pi_k) π=(π1,,πk)的categorical随机变量。categorical分布的样本表示为 k k k维的one-hot向量,在 k − 1 k-1 k1维的单纯形空间 △ k − 1 \bigtriangleup^{k-1} k1中。

很多情况下需要对categorical分布采样,并优化categorical分布的参数。但是采样的参数没法直接优化,所以需要重参数技巧(reparametrization tricks)。针对categorical分布的重参数技巧是Gumbel-Max trick。

具体而言,Gumbel-Max trick是重参数技巧的一个特例,其提供了一个简单有效的从categorical分布采样的方法:
z = one-hot ( argmax ⁡ i [ g i + log ⁡ π i ] ) (1) z = \text{one-hot}\left(\operatorname{argmax}_i[g_i + \log \pi_i]\right) \tag{1} z=one-hot(argmaxi[gi+logπi])(1)其中 g i ∼ G u m b e l ( 0 , 1 ) g_i \sim Gumbel(0,1) giGumbel(0,1)Gumbel分布用于对各种分布的多个样本的最大值(或最小值)的分布进行建模。Gumbel分布的概率密度是:
G u m b e l ( μ , β ) = 1 β exp ⁡ ( − x − μ β + exp ⁡ ( − x − μ β ) ) Gumbel(\mu, \beta) = \frac{1}{\beta}\exp(-\frac{x - \mu}{\beta} + \exp(-\frac{x - \mu}{\beta})) Gumbel(μ,β)=β1exp(βxμ+exp(βxμ))

Gumbel-Softmax分布

上面公式(1)中的argmax是不可导的,用可导的softmax函数去近似其中的argmax,于是得到样本 y ∈ △ k − 1 y\in\bigtriangleup^{k-1} yk1
y i = softmax ⁡ [ g i + log ⁡ π i ] = exp ⁡ ( log ⁡ π i + g i τ ) ∑ j = 1 k exp ⁡ ( log ⁡ π j + g j τ ) y_i = \operatorname{softmax}[g_i + \log \pi_i] = \frac{\exp(\frac{\log \pi_i + g_i}{\tau})}{\sum_{j = 1}^k \exp(\frac{\log \pi_j + g_j}{\tau})} yi=softmax[gi+logπi]=j=1kexp(τlogπj+gj)exp(τlogπi+gi) y i y_i yi服从Gumbel-Softmax分布。Gumbel-Softmax分布的概率密度函数是:
论文笔记CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX
随着 τ \tau τ趋近于0,Gumbel-Softmax分布的样本逐渐变成one-hot的,Gumbel-Softmax分布也逐渐变成了categorical分布。如下图所示:

论文笔记CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX

Gumbel-Softmax Estimator

Gumbel-Softmax分布的 ∂ y ∂ π \frac{\partial y}{\partial \pi} πy是有定义的。
通过用Gumbel-Softmax样本替换categorical样本,我们可以使用反向传播来计算梯度。
把在训练阶段,用可导的Gumbel-Softmax样本替代不可导的categorical样本的过程称为Gumbel-Softmax Estimator。

在温度 τ \tau τ小时,样本接近one-hot但梯度方差大,在温度 τ \tau τ大时,样本平滑但梯度方差小。
实际中,我们从高温 τ \tau τ开始,然后退火到一个很小但非零的温度。

Straight-Through (ST) Gumbel-Softmax Estimator

Straight-Through Estimator (STE)

首先介绍下Straight-Through Estimator (STE)。
STE是量化(quantization)中常见的求导方式。
比如有sign函数:
w b = sign ⁡ ( w ) = { + 1 ,  if  w ≥ 0 − 1 ,  otherwise  w_{b}=\operatorname{sign}(w)=\left\{ \begin{array}{ll}{+1,}{\text { if } w \geq 0} \\ {-1,}{\text { otherwise }}\end{array}\right. wb=sign(w)={+1, if w01, otherwise 这个sign函数在定义域范围内导数都是0。STE就是用来解决sign函数梯度无法反传的问题的。
二值网络训练过程可以是这样:模型中每个参数其实都是一个浮点型的数,每次迭代其实都是在更新这个浮点型数。但是,在前向传播的过程中,先用sign函数对浮点型参数二值化处理然后再参与到运算,而此时并没有把这个浮点型数值抛弃掉,而是暂时在内存中保存起来。前向传播完之后,网络得到一个输出,就可以接着通过反向传播算出二值参数的梯度,再直接用这个梯度来更新对应的浮点型参数。这样,前向反向就跑通了。等训练的差不多了,就最后对模型的这些浮点型参数做一次二值化处理形成最终的二值网络,此时浮点型的参数就完成了任务,可以被抛弃掉了。

Straight-Through (ST) Gumbel-Softmax Estimator

在前向的时候使用argmax离散化 y y y,但在梯度反传的时候,使用连续近似 ∇ θ z ≈ ∇ θ y \nabla_\theta z \approx \nabla_\theta y θzθy

参考

ICLR 2017 Categorical Reparameterization with Gumbel-Softmax
Emma Benjaminson blog
二值网络,围绕STE的那些事儿
gumbel-max-trick的数学证明