Generative Pretraining From Pixels
引用: Chen M, Radford A, Child R, et al. Generative pretraining from pixels[C]//International conference on machine learning. PMLR, 2020: 1691-1703.
论文链接: http://proceedings.mlr.press/v119/chen20s.html
简介
受自然语言中无监督表示学习进展的启发,作者研究了类似的模型是否能够学习图像的有用表示,训练了一个序列Transformer来自回归地预测像素,而不包含2D输入结构的知识。尽管是在低分辨率的ImageNet上进行训练,没有标签,但实验发现一个GPT-2规模的模型通过线性探测、微调和低数据分类学习,学习到了强大的图像表示。在CIFAR-10上,使用线性探测达到了96.3%的准确率,超过了监督的Wide ResNet,全微调达到了99.0%的准确率,与*监督预训练模型相匹配。同时,作者还在ImageNet上与自监督基准进行了比较,通过将像素替换为VQVAE编码,在线性探测特征时达到了69.0%的top-1准确率。
Method
论文的方法包括预训练阶段和微调阶段。在预训练中,探索了auto-regressive和BERT,还应用序列Transformer架构来预测像素,而不是语言标记。而测量表征质量的一种方法是对图像分类进行微调。微调为模型添加了一个小的分类头,用于优化分类目标并调整所有权重。当与早停结合使用时,预训练可以被视为一种有利的初始化或正则化。另一种方法则使用预先训练的模型作为特征提取器。特别地,给定标记的示例(X,Y),将模型应用于X以产生特征fx。然后,在(fx,Y)上训练线性分类器。线性探测源自一种直觉,即好的特征应该线性地分离转移任务的类别。此外,线性探测有助于将特征质量与模型架构区分开来:在微调中,一个模型可能优于另一个模型,因为它的架构更适合下游任务,而不是因为更好的预训练。
Pre-training
给定由高维数据
X
=
(
x
1
,
.
.
.
,
x
n
)
X=(x_1,...,x_n)
X=(x1,...,xn)组成的未标记数据集
X
X
X,可以选择集合
[
1
,
n
]
[1,n]
[1,n]的排列π,并对密度
p
(
x
)
p(x)
p(x)进行自回归建模:
当处理图像时,选择
1
≤
i
≤
n
1≤i≤n
1≤i≤n的单位置换
π
i
=
i
π_i=i
πi=i,也称为光栅顺序。通过最小化数据的负对数似然来训练模型:
对于BERT目标,其采样为子序列
M
⊂
[
1
,
n
]
M⊂[1,n]
M⊂[1,n],使得每个索引
i
i
i独立地具有出现在
M
M
M中的概率为0.15。称
M
M
M为BERT掩码,并且通过最小化以“未掩码”为条件的“掩码”元素
x
M
x_M
xM的负对数似然来训练模型:
Architecture
transformer decoder取一个输入序列
x
1
,
.
.
.
,
x
n
x_1,...,x_n
x1,...,xn,并为每个位置产生
d
d
d维嵌入。解码器被实现为
L
L
L个块的堆栈,其中第
l
l
l个产生中间嵌入
h
l
1
,
.
.
.
,
h
l
n
h_l^1,...,h_l^n
hl1,...,hln也是维数d。我们使用transformer decoder块的GPT-2公式,它作用于输入张量
h
l
h_l
hl如下:
特别地,**层规范在注意力机制和MLP之前,并且所有运算都位于残差路径上。**这样的配置可以轻松地缩放transformer。
序列元素之间的唯一混合发生在注意力操作中,为了确保在训练AR目标时进行适当的调节,将标准的上三角掩码应用于注意力逻辑的n×n矩阵。当使用BERT目标时,不需要注意logit掩蔽:在将内容嵌入应用于输入序列之后,将M中的位置清零。
此外,由于学习了每个序列元素的独立位置嵌入,BERT模型没有位置归纳偏差(即它是置换不变的)。换句话说,位置之间的任何空间关系都必须由模型在训练时学习。对于AR模型来说,这并不完全正确,因为选择光栅顺序也会修复预先指定的条件顺序。然而,置换不变性是与卷积神经网络形成强烈对比的一种特性,卷积神经网络包含了特征应该从空间上接近的元素产生的归纳偏差。
Fine-tuning
当进行微调时,我们对序列的 n L n^L nL维度进行平均池化,以提取每个示例的特征的d维向量。然后,学习从 f L f_L fL到类别的logits的投影,使用它来最小化交叉熵损失。
Linear Probing
为线性探测提取固定特征遵循与微调类似的过程,只是平均池化并不总是在最后一层:
其中0≤l≤l。实验表明,最佳特征通常位于网络的中间。在微调中,投影这些中间特征以产生类logits。
实验
表征质量在很大程度上取决于提取特征的层。与监督模型相比,这些生成模型的最佳表征位于网络的中间层。