2.1.模型架构
ALBEF模型架构组成:
ALBEF 包含一个图像编码器、一个文本编码器和一个多模态编码器。
1.图像编码器:
使用12 层视觉变换器ViT-B/16作为图像编码器,并使用在ImageNet-1k上预先训练好的权重对其进行初始化
。输入图像
I
I
I 被编码为一系列嵌入式数据:
{
v
c
l
s
,
v
1
,
.
.
.
,
v
N
}
\{v_{cls}, v_1, ..., v_N\}
{vcls,v1,...,vN},其中
v
c
l
s
v_{cls}
vcls是 [CLS] 标记的嵌入。
2.文本编码器:
文本编码器使用 BERTbase模型的前 6 层进行初始化,
。文本编码器将输入文本
T
T
T 转换为嵌入序列
{
w
c
l
s
,
w
1
,
.
.
.
,
w
N
}
\{w_{cls}, w_1, ..., w_N \}
{wcls,w1,...,wN},并将其输入多模态编码器。
3.多模态编码器:
多模态编码器使用BERTbase 的后 6 层进行初始化,在多模态编码器的每一层,通过交叉关注将图像特征与文本特征融合
2.2.预训练目标
ALBEF使用三种预训练目标:
在单模态编码器上进行图像-文本对比学习(ITC),在多模态编码器上进行屏蔽语言建模(MLM)和图像-文本匹配(ITM)
1.图像-文本对比学习(ITC):
ITC目的是:在融合之前学习更好的单模态表征,即融合前先对齐。
它学习一个相似度函数
s
=
g
v
(
v
c
l
s
)
T
g
w
(
w
c
l
s
)
s = g_v(v_{cls})^Tg_w(w_{cls})
s=gv(vcls)Tgw(wcls)。其中,
g
v
g_v
gv和
g
w
g_w
gw是将[CLS]嵌入映射到归一化的低维(256维)表示的线性变换。
ALBEF采用MoCo的做法,即维护两个队列,用于存储来自动量单模态编码器的最近的M个图像-文本表示。动量编码器的归一化特征被表示为 g v ′ ( v c l s ′ ) g'_v(v'_{cls}) gv′(vcls′)和 g w ′ ( w c l s ′ ) g'_w(w'_{cls}) gw′(wcls′)。
将相似度分数定义为 s ( I , T ) = g v ( v c l s ) ⊤ g w ′ ( w c l s ′ ) s(I,T) = g_v(v_{cls})^\top g'_w(w'_{cls}) s(I,T)=gv(vcls)⊤gw′(wcls′) 和 s ( T , I ) = g w ( w c l s ) ⊤ g v ′ ( v c l s ′ ) s(T,I) = g_w(w_{cls})^\top g'_v(v'_{cls}) s(T,I)=gw(wcls)⊤gv′(vcls′)。
对于每幅图像和每段文字,计算图像到文字的 softmax 归一化相似度和文字到图像的 softmax 归一化相似度:
p m i 2 t ( I ) = exp ( s ( I , T m ) / τ ) ∑ m = 1 M exp ( s ( I , T m ) / τ ) p m t 2 i ( T ) = exp ( s ( T , I m ) / τ ) ∑ m = 1 M exp ( s ( T , I m ) / τ ) \begin{equation} \begin{split} p_m^\mathrm{i2t}(I) = \frac{\exp (s(I,T_m) / \tau)}{\sum_{m=1}^M \exp (s(I,T_m)/ \tau)} \\ p_m^\mathrm{t2i}(T) = \frac{\exp (s(T,I_m)/ \tau)}{\sum_{m=1}^M \exp (s(T,I_m)/ \tau)} \end{split} \end{equation} pmi2t(I)=∑m=1Mexp(s(I,Tm)/τ)exp(s(I,Tm)/τ)pmt2i(T)=∑m=1Mexp(s(T,Im)/τ)exp(s(T,Im)/τ)
其中, τ \tau τ是一个可学习的温度参数。让 y i 2 t ( I ) {y}^\mathrm{i2t}(I) yi2t(I) 和 y t 2 i ( T ) {y}^\mathrm{t2i}(T) yt2i(T) 表示真实one-hot相似度。图像-文本对比度损失定义为 p 和 y 之间的交叉熵 H:
L i t c = 1 2 E ( I , T ) ∼ D [ H ( y i 2 t ( I ) , p i 2 t ( I ) ) + H ( y t 2 i ( T ) , p t 2 i ( T ) ) ] \begin{equation} \begin{split} \mathcal{L}_\mathrm{itc} = \frac{1}{2} \mathbb{E}_{(I,T)\sim D} \big[ \mathrm{H}({y}^\mathrm{i2t}(I),{p}^\mathrm{i2t}(I)) + \mathrm{H}({y}^\mathrm{t2i}(T),{p}^\mathrm{t2i}(T)) \big] \end{split} \end{equation} Litc=21E(I,T)∼D[H(yi2t(I),pi2t(I))+H(yt2i(T),pt2i(T))]
2.MASK语言建模(MLM):
MLM目的是:利用图像和上下文文本来预测屏蔽词。
以 15% 的概率随机Mask掉输入标记,并用特殊标记 [MASK]取而代之。让
T
^
\hat{T}
T^ 表示Mask文本,
p
msk
(
I
,
T
^
)
{p}^\textrm{msk}(I,\hat{T})
pmsk(I,T^)表示模型对Mask标记的预测概率。MLM 将交叉熵损失最小化:
L m l m = E ( I , T ^ ) ∼ D H ( y msk , p msk ( I , T ^ ) ) \begin{equation} \mathcal{L}_\mathrm{mlm} = \mathbb{E}_{(I,\hat{T})\sim D} \mathrm{H} ({y}^\textrm{msk}, {p}^\textrm{msk}(I,\hat{T})) \end{equation} Lmlm=E(I,T^)∼DH(ymsk,pmsk(I,T^))
3.图像-文本匹配(ITM):
ITM目的是:预测一对图像和文本是正样本(匹配)还是负样本(不匹配)。
将多模态编码器对 [CLS] 标记的输出嵌入后面附加一个全连接(FC)层,然后使用 softmax 预测两类概率
p
i
t
m
p^\mathrm{itm}
pitm。ITM 损失为
L
i
t
m
=
E
(
I
,
T
)
∼
D
H
(
y
itm
,
p
itm
(
I
,
T
)
)
\begin{equation} \mathcal{L}_\mathrm{itm} = \mathbb{E}_{(I,T)\sim D} \mathrm{H} ({y}^\textrm{itm},{p}^\textrm{itm}(I,T)) \end{equation}
Litm=E(I,T)∼DH(yitm,pitm(I,T))
ALBEF对ITM设计的训练策略:对比硬负例挖掘:
硬负例:如果一个负面的图像-文本对在语义上相似但在细节上有所不同,那么它就是一个硬负例。
ALBEF使用等式1中的对比相似度来找到一个Batch内的硬负例。对于每个小Batch中的图像,从同一Batch的文本中采样一个负面文本,遵循对比相似度的分布,其中与图像更相似的文本被采样的概率更高。同样地,我们也为每个文本采样一个硬负例图像。具体的流程:
- 计算图像与文本的相似度:对于每个图像,计算其与批次中所有文本的相似度
- 采样负面文本:对于每个图像,根据与文本的相似度分布,从批次中的文本中采样一个负面文本。采样概率与相似度呈正相关,即与图像更相似的文本被选中的概率更高。(我理解的是,抛开batch中的唯一正例,其余负例中选择相似度最高或最高的有最大的概率作为硬负例)
- 计算文本与图像的相似度:对于每个文本,计算其与批次中所有图像的相似度。
- 采样负面图像:对于每个文本,根据与图像的相似度分布,从批次中的图像中采样一个负面图像。同样地,采样概率与相似度呈正相关,即与文本更相似的图像被选中的概率更高。
- ALBEF最终的训练损失:
ALBEF 的全部预训练目标是:
L = L i t c + L m l m + L i t m \begin{equation} \mathcal{L} = \mathcal{L}_\mathrm{itc} + \mathcal{L}_\mathrm{mlm} + \mathcal{L}_\mathrm{itm} \end{equation} L=Litc+Lmlm+Litm
2.3.动量蒸馏
网络中的噪声数据对预训练任务带来的影响:
用于预训练的图像-文本配对大多是从网络上收集的,它们往往是有噪声的。
正图文对是弱相关的:文本可能包含与图像无关的词语,或者图像可能包含文本中没有描述的实体
。
对于ITC影响:图像的负文本也可能与图像内容相匹配
。
对于MLM影响:可能存在与注释不同的其他词语,它们对图像的描述更好
。
但是,ITC 和 MLM 的单点标签会对所有负面预测进行惩罚,无论其正确与否。
为了解决噪声问题,提出的动量蒸馏训练策略:
动量模型是一个持续演化的教师,由单模态和多模态编码器的指数移动平均版本组成。
在训练过程中,对基础模型进行训练,使其预测结果与动量模型的预测结果相匹配。具体来说:
对于 ITC:首先使用动量单模态编码器的特征计算图像-文本相似度,即 s ′ ( I , T ) = g v ′ ( v c l s ′ ) ⊤ g w ′ ( w c l s ′ ) s'(I,T) = g'_v({v}'_{cls})^\top g'_w({w}'_{cls}) s′(I,T)=gv′(vcls′)⊤gw′(wcls′)和 s ′ ( T , I ) = g w ′ ( w c l s ) ⊤ g v ′ ( v c l s ′ ) s'(T,I) = g'_w({w}_{cls})^\top g'_v({v}'_{cls}) s′(T,I)=gw′(wcls)⊤gv′(vcls′)。然后,用等式 1 中的 s ′ s' s′ 代替 s s s ,计算软伪目标 q i 2 t {q}^{i2t} qi2t和 q t 2 i {q}^{t2i} qt2i。ITC M o D _\mathrm{MoD} MoD 损失定义如下
L i t c m o d = ( 1 − α ) L i t c + α 2 E ( I , T ) ∼ D [ K L ( q i 2 t ( I ) ∥ p i 2 t ( I ) ) + K L ( q t 2 i ( T ) ∥ p t 2 i ( T ) ) ] \begin{equation} \mathcal{L}_\mathrm{itc}^\mathrm{mod} = (1-\alpha) \mathcal{L}_\mathrm{itc} + \frac{\alpha }{2} \mathbb{E}_{(I,T)\sim D} \big[ \mathrm{KL}({q}^\mathrm{i2t}(I) \parallel {p}^\mathrm{i2t}(I)) + \mathrm{KL}({q}^\mathrm{t2i}(T)\parallel {p}^\mathrm{t2i}(T))\big] \end{equation} Litcmod=(1−α)Litc+2αE(I,T)∼D[KL(qi2t(I)∥pi2t(I))+KL(qt2i(T)∥pt2i(T))]
对于 MLM:让 q m s k ( I , T ^ ) {q}^{msk}(I,\hat{T}) qmsk(I,T^)表示动量模型对mask标记的预测概率,MLM M o D _\mathrm{MoD} MoD 损失为
L m l m m o d = ( 1 − α ) L m l m + α E ( I , T ^ ) ∼ D K L ( q msk ( I , T ^ ) ∥ p msk ( I , T ^ ) ) \begin{equation} \mathcal{L}_\mathrm{mlm}^\mathrm{mod} = (1-\alpha) \mathcal{L}_\mathrm{mlm} + \alpha \mathbb{E}_{(I,\hat{T})\sim D} \mathrm{KL} ({q}^\textrm{msk}(I,\hat{T})\parallel {p}^\textrm{msk}(I,\hat{T})) \end{equation} Lmlmmod=(1−α)Lmlm+αE(I,T^)∼DKL(qmsk(I,T^)∥pmsk(I,T^))
在图中,上面一行是 ITC 的前5名伪目标,下面一行是 MLM 的前5名伪目标。可以看到对于有的图,伪目标能概括出 GT 里面没有描述出来的物体。比如左下角的图,GT 说的是 “路上的车抛锚了”,但是 top-5 的伪标签不仅能够描述这个意思,还额外地描述了 “年轻女士” 这一信息。
动量蒸馏总结:
动量模型其实就是另一套参数的 ALBEF(教师模型)。个人理解:
- 先训练出一版ALBEF,然后将其平均版本作为动量模型(或不用训练出来,训练一部分就开始加入动量模型)
- 用动量模型输出的伪目标作为额外的监督标准,即在原始损失的基础上加入模型预测与伪目标之间的 KL-发散的加权组合。
- 动态的更新动量模型的参数权重