多模态之ALBEF—先对齐后融合,利用动量蒸馏学习视觉语言模型表征,学习细节理解与论文详细阅读:Align before Fuse-2.Model

时间:2024-04-17 17:56:41

在这里插入图片描述

2.1.模型架构

ALBEF模型架构组成:

ALBEF 包含一个图像编码器、一个文本编码器和一个多模态编码器

1.图像编码器:

使用12 层视觉变换器ViT-B/16作为图像编码器,并使用在ImageNet-1k上预先训练好的权重对其进行初始化。输入图像 I I I 被编码为一系列嵌入式数据: { v c l s , v 1 , . . . , v N } \{v_{cls}, v_1, ..., v_N\} {vcls,v1,...,vN},其中 v c l s v_{cls} vcls是 [CLS] 标记的嵌入。

2.文本编码器:

文本编码器使用 BERTbase模型的前 6 层进行初始化,。文本编码器将输入文本 T T T 转换为嵌入序列 { w c l s , w 1 , . . . , w N } \{w_{cls}, w_1, ..., w_N \} {wcls,w1,...,wN},并将其输入多模态编码器。

3.多模态编码器:

多模态编码器使用BERTbase 的后 6 层进行初始化,在多模态编码器的每一层,通过交叉关注将图像特征与文本特征融合

2.2.预训练目标

ALBEF使用三种预训练目标:

在单模态编码器上进行图像-文本对比学习(ITC),在多模态编码器上进行屏蔽语言建模(MLM)和图像-文本匹配(ITM)

1.图像-文本对比学习(ITC):

ITC目的是:在融合之前学习更好的单模态表征,即融合前先对齐
它学习一个相似度函数 s = g v ( v c l s ) T g w ( w c l s ) s = g_v(v_{cls})^Tg_w(w_{cls}) s=gv(vcls)Tgw(wcls)。其中, g v g_v gv g w g_w gw是将[CLS]嵌入映射到归一化的低维(256维)表示的线性变换。

ALBEF采用MoCo的做法,即维护两个队列,用于存储来自动量单模态编码器的最近的M个图像-文本表示。动量编码器的归一化特征被表示为 g v ′ ( v c l s ′ ) g'_v(v'_{cls}) gv(vcls) g w ′ ( w c l s ′ ) g'_w(w'_{cls}) gw(wcls)

将相似度分数定义为 s ( I , T ) = g v ( v c l s ) ⊤ g w ′ ( w c l s ′ ) s(I,T) = g_v(v_{cls})^\top g'_w(w'_{cls}) s(I,T)=gv(vcls)gw(wcls) s ( T , I ) = g w ( w c l s ) ⊤ g v ′ ( v c l s ′ ) s(T,I) = g_w(w_{cls})^\top g'_v(v'_{cls}) s(T,I)=gw(wcls)gv(vcls)

对于每幅图像和每段文字,计算图像到文字的 softmax 归一化相似度和文字到图像的 softmax 归一化相似度:

p m i 2 t ( I ) = exp ⁡ ( s ( I , T m ) / τ ) ∑ m = 1 M exp ⁡ ( s ( I , T m ) / τ ) p m t 2 i ( T ) = exp ⁡ ( s ( T , I m ) / τ ) ∑ m = 1 M exp ⁡ ( s ( T , I m ) / τ ) \begin{equation} \begin{split} p_m^\mathrm{i2t}(I) = \frac{\exp (s(I,T_m) / \tau)}{\sum_{m=1}^M \exp (s(I,T_m)/ \tau)} \\ p_m^\mathrm{t2i}(T) = \frac{\exp (s(T,I_m)/ \tau)}{\sum_{m=1}^M \exp (s(T,I_m)/ \tau)} \end{split} \end{equation} pmi2t(I)=m=1Mexp(s(I,Tm)/τ)exp(s(I,Tm)/τ)pmt2i(T)=m=1Mexp(s(T,Im)/τ)exp(s(T,Im)/τ)

其中, τ \tau τ是一个可学习的温度参数。让 y i 2 t ( I ) {y}^\mathrm{i2t}(I) yi2t(I) y t 2 i ( T ) {y}^\mathrm{t2i}(T) yt2i(T) 表示真实one-hot相似度。图像-文本对比度损失定义为 p 和 y 之间的交叉熵 H:

L i t c = 1 2 E ( I , T ) ∼ D [ H ( y i 2 t ( I ) , p i 2 t ( I ) ) + H ( y t 2 i ( T ) , p t 2 i ( T ) ) ] \begin{equation} \begin{split} \mathcal{L}_\mathrm{itc} = \frac{1}{2} \mathbb{E}_{(I,T)\sim D} \big[ \mathrm{H}({y}^\mathrm{i2t}(I),{p}^\mathrm{i2t}(I)) + \mathrm{H}({y}^\mathrm{t2i}(T),{p}^\mathrm{t2i}(T)) \big] \end{split} \end{equation} Litc=21E(I,T)D[H(yi2t(I),pi2t(I))+H(yt2i(T),pt2i(T))]

2.MASK语言建模(MLM):

MLM目的是:利用图像和上下文文本来预测屏蔽词
以 15% 的概率随机Mask掉输入标记,并用特殊标记 [MASK]取而代之。让 T ^ \hat{T} T^ 表示Mask文本, p msk ( I , T ^ ) {p}^\textrm{msk}(I,\hat{T}) pmsk(I,T^)表示模型对Mask标记的预测概率。MLM 将交叉熵损失最小化:

L m l m = E ( I , T ^ ) ∼ D H ( y msk , p msk ( I , T ^ ) ) \begin{equation} \mathcal{L}_\mathrm{mlm} = \mathbb{E}_{(I,\hat{T})\sim D} \mathrm{H} ({y}^\textrm{msk}, {p}^\textrm{msk}(I,\hat{T})) \end{equation} Lmlm=E(I,T^)DH(ymsk,pmsk(I,T^))

3.图像-文本匹配(ITM):

ITM目的是:预测一对图像和文本是正样本(匹配)还是负样本(不匹配)
将多模态编码器对 [CLS] 标记的输出嵌入后面附加一个全连接(FC)层,然后使用 softmax 预测两类概率 p i t m p^\mathrm{itm} pitm。ITM 损失为
L i t m = E ( I , T ) ∼ D H ( y itm , p itm ( I , T ) ) \begin{equation} \mathcal{L}_\mathrm{itm} = \mathbb{E}_{(I,T)\sim D} \mathrm{H} ({y}^\textrm{itm},{p}^\textrm{itm}(I,T)) \end{equation} Litm=E(I,T)DH(yitm,pitm(I,T))

ALBEF对ITM设计的训练策略:对比硬负例挖掘:

硬负例:如果一个负面的图像-文本对在语义上相似但在细节上有所不同,那么它就是一个硬负例。
ALBEF使用等式1中的对比相似度来找到一个Batch内的硬负例。对于每个小Batch中的图像,从同一Batch的文本中采样一个负面文本,遵循对比相似度的分布,其中与图像更相似的文本被采样的概率更高。同样地,我们也为每个文本采样一个硬负例图像。具体的流程:

  1. 计算图像与文本的相似度:对于每个图像,计算其与批次中所有文本的相似度
  2. 采样负面文本:对于每个图像,根据与文本的相似度分布,从批次中的文本中采样一个负面文本。采样概率与相似度呈正相关,即与图像更相似的文本被选中的概率更高。(我理解的是,抛开batch中的唯一正例,其余负例中选择相似度最高或最高的有最大的概率作为硬负例)
  3. 计算文本与图像的相似度:对于每个文本,计算其与批次中所有图像的相似度。
  4. 采样负面图像:对于每个文本,根据与图像的相似度分布,从批次中的图像中采样一个负面图像。同样地,采样概率与相似度呈正相关,即与文本更相似的图像被选中的概率更高。
  1. ALBEF最终的训练损失:

ALBEF 的全部预训练目标是:

L = L i t c + L m l m + L i t m \begin{equation} \mathcal{L} = \mathcal{L}_\mathrm{itc} + \mathcal{L}_\mathrm{mlm} + \mathcal{L}_\mathrm{itm} \end{equation} L=Litc+Lmlm+Litm

2.3.动量蒸馏

网络中的噪声数据对预训练任务带来的影响:

用于预训练的图像-文本配对大多是从网络上收集的,它们往往是有噪声的。
正图文对是弱相关的文本可能包含与图像无关的词语,或者图像可能包含文本中没有描述的实体
对于ITC影响图像的负文本也可能与图像内容相匹配
对于MLM影响可能存在与注释不同的其他词语,它们对图像的描述更好

但是,ITC 和 MLM 的单点标签会对所有负面预测进行惩罚,无论其正确与否

为了解决噪声问题,提出的动量蒸馏训练策略:

动量模型是一个持续演化的教师,由单模态和多模态编码器的指数移动平均版本组成

在训练过程中,对基础模型进行训练,使其预测结果与动量模型的预测结果相匹配。具体来说:

对于 ITC:首先使用动量单模态编码器的特征计算图像-文本相似度,即 s ′ ( I , T ) = g v ′ ( v c l s ′ ) ⊤ g w ′ ( w c l s ′ ) s'(I,T) = g'_v({v}'_{cls})^\top g'_w({w}'_{cls}) s(I,T)=gv(vcls)gw(wcls) s ′ ( T , I ) = g w ′ ( w c l s ) ⊤ g v ′ ( v c l s ′ ) s'(T,I) = g'_w({w}_{cls})^\top g'_v({v}'_{cls}) s(T,I)=gw(wcls)gv(vcls)。然后,用等式 1 中的 s ′ s' s 代替 s s s ,计算软伪目标 q i 2 t {q}^{i2t} qi2t q t 2 i {q}^{t2i} qt2i。ITC M o D _\mathrm{MoD} MoD 损失定义如下

L i t c m o d = ( 1 − α ) L i t c + α 2 E ( I , T ) ∼ D [ K L ( q i 2 t ( I ) ∥ p i 2 t ( I ) ) + K L ( q t 2 i ( T ) ∥ p t 2 i ( T ) ) ] \begin{equation} \mathcal{L}_\mathrm{itc}^\mathrm{mod} = (1-\alpha) \mathcal{L}_\mathrm{itc} + \frac{\alpha }{2} \mathbb{E}_{(I,T)\sim D} \big[ \mathrm{KL}({q}^\mathrm{i2t}(I) \parallel {p}^\mathrm{i2t}(I)) + \mathrm{KL}({q}^\mathrm{t2i}(T)\parallel {p}^\mathrm{t2i}(T))\big] \end{equation} Litcmod=(1α)Litc+2αE(I,T)D[KL(qi2t(I)pi2t(I))+KL(qt2i(T)pt2i(T))]

对于 MLM:让 q m s k ( I , T ^ ) {q}^{msk}(I,\hat{T}) qmsk(I,T^)表示动量模型对mask标记的预测概率,MLM M o D _\mathrm{MoD} MoD 损失为

L m l m m o d = ( 1 − α ) L m l m + α E ( I , T ^ ) ∼ D K L ( q msk ( I , T ^ ) ∥ p msk ( I , T ^ ) ) \begin{equation} \mathcal{L}_\mathrm{mlm}^\mathrm{mod} = (1-\alpha) \mathcal{L}_\mathrm{mlm} + \alpha \mathbb{E}_{(I,\hat{T})\sim D} \mathrm{KL} ({q}^\textrm{msk}(I,\hat{T})\parallel {p}^\textrm{msk}(I,\hat{T})) \end{equation} Lmlmmod=(1α)Lmlm+αE(I,T^)DKL(qmsk(I,T^)pmsk(I,T^))

在图中,上面一行是 ITC 的前5名伪目标,下面一行是 MLM 的前5名伪目标。可以看到对于有的图,伪目标能概括出 GT 里面没有描述出来的物体。比如左下角的图,GT 说的是 “路上的车抛锚了”,但是 top-5 的伪标签不仅能够描述这个意思,还额外地描述了 “年轻女士” 这一信息。

在这里插入图片描述

动量蒸馏总结:

动量模型其实就是另一套参数的 ALBEF(教师模型)个人理解

  1. 先训练出一版ALBEF,然后将其平均版本作为动量模型(或不用训练出来,训练一部分就开始加入动量模型)
  2. 用动量模型输出的伪目标作为额外的监督标准,即在原始损失的基础上加入模型预测与伪目标之间的 KL-发散的加权组合。
  3. 动态的更新动量模型的参数权重