AFT:Attention Free Transformer论文笔记

时间:2024-07-13 07:02:14

原文链接

2105.14103 (arxiv.org)

原文翻译

Abstract

我们介绍了 Attention Free Transformer (AFT),这是 Transformer [1] 的有效变体,它消除了点积自注意力的需要。在 AFT 层,键key和值value首先与一组学习的位置偏差position biases相结合,其结果以元素方式与查询相乘。这种新操作的内存复杂度为线性 w.r.t。上下文大小和特征维度,使其与大输入和模型大小兼容。我们还引入了 AFT-local 和 AFT-conv,这是两个模型变体,它利用了局部性和空间权重共享的思想,同时保持全局连通性。我们在两个自回归建模任务(CIFAR10 和 Enwik8)以及图像识别任务(ImageNet-1K 分类)上进行了广泛的实验。我们表明 AFT 在所有基准测试中都表现出具有竞争力的性能,同时提供了出色的效率。

1 Introduction

以Transformers[1]为代表的自注意机制推动了各种机器学习问题的发展,包括语言理解[2,3]和计算机视觉应用[4 - 6]。与卷积神经网络(cnn)或循环神经网络(rnn)等经典模型架构不同,变形金刚可以在序列中的每对元素之间进行直接交互,这使得它们在捕获长期依赖关系方面特别强大。

然而,变压器需要很高的计算成本。这一挑战的原因是需要执行具有二次时间和空间复杂性的注意力操作,这涉及上下文大小。这使得transformer难以扩展到具有大上下文大小的输入。最近的许多工作都致力于解决transformer的可伸缩性问题[7 -13]。这里的共同思想是近似全注意力操作,使用的技术包括稀疏性、局域敏感散列、低秩分解、核近似等。

在本文中,我们提出了一个不使用或近似标准点积注意力的计算模块。因此,我们将我们的模型命名为不使用注意力的Transformer (AFT)。与点积注意力类似,AFT 由查询、键和值 (Q, K, V) 三个量的交互组成。不同之处在于,在 AFT 中,键和值(上下文)首先与一组可学习的位置偏执相结合然后使用元素乘法将查询与缩减的上下文相结合。有关说明,请参见图 2。

AFT 保留了在上下文中任意两个点之间的直接交互,这是点积注意力的主要优势。事实上,AFT 可以解释为执行注意力,其中注意力头的数量与模型特征维度相同,而注意力图不需要显式计算(详见第 3.1 节)。这导致内存复杂度线性 w.r.t。输入和模型大小。

Q、K、V 的重新排列计算排序在最近的“线性化注意力”工作中也被发现 [11, 13 –15]。不同之处在于 AFT 以元素方式组合 k 和 v,而所有线性注意力论文都依赖于矩阵点积。后一种方法导致复杂度与模型特征维度的二次方,这对大型模型大小不友好。有关 AFT 与其他变体相比的复杂性分析,请参见表 1。

根据经验,我们观察到经过训练的 Transformer 往往表现出广泛的局部模式(见图 1)。这促使我们提出了两种 AFT 变体:AFT-local 和 AFT-conv。在 AFT-local 中,学习到的位置偏差被限制在局部区域,同时保持全局连接。AFT-conv 通过施加空间权重共享进一步扩展了这种设计,有效地使其成为具有全局感受野的 CNN 变体。我们表明,局部性约束不仅提供了更好的参数计算效率,而且大大提高了模型在所有任务中的表现。

我们在图像自回归建模、字符级语言建模和图像分类任务上使用 AFT 进行了实验。我们表明,AFT 提供了具有竞争力的性能,通常匹配或击败标准 Transformer 和其他变体(的准确度),同时提供了出色的效率。我们还对 AFT 的几种设计选择进行了广泛的消融研究,并讨论了它的独特属性,例如与 Transformer的兼容性、稀疏性和输入大小的可变性。

2 Multi-Head Attention

Transformers 的核心是多头注意力 (MHA) 操作。在自注意模式下,给定一个输入序列 X ∈ R^T ×d 和头部的数量 h,MHA 对每个头部 i 执行缩放的点积注意力,定义为:

其中 W Q i ∈ R^d×dk , W K i ∈ R^d×dk , W V i ∈ R^d×dv 是头部 i 的线性变换,σ 是默认设置为 sof tmax 函数的非线性(应用于矩阵的每一行)。dk, dv 分别是键和值的维度。MHA 将 h 个注意力头的输出沿通道维度拼接起来,得到特征维度 hdv。除非另有说明,我们假设dk=dv和h=d/dk。这意味着查询、键和值在每个头内都是相同的维度,输出维度与输入的维度匹配。

3 Methodology

3.1 Attention Free Transformer

我们现在定义 Attention free Transformer (AFT),它是 MHA 的插件替换,而不需要更改 Transformer 的其他架构方面。给定输入 X,AFT 首先将它们线性变换为 Q = XW^Q, K=XW^K,V =XW^V ,然后进行以下操作 2:

其中 是元素乘积; σq 是应用于query的非线性,默认为 sigmoid; w ∈ RT ×T 是学习的成对位置偏差(参见图 2 的说明)。

简而言之,对于每个目标位置t, AFT执行value的加权平均值,其结果与query进行元素间乘法相结合。具体来说,相结合的权重只是由键和一组学习得到的成对位置偏差组成。这提供了不需要计算和存储昂贵的注意力矩阵的直接优势,同时像MHA那样维护查询和值之间的全局交互。为了进一步了解AFT与MHA的关系,我们可以将方程2改写为:

这里我们使用上标 i 来索引矩阵的特征维度; <·, · >; 表示向量的点积。在这个重新排列的形式中,我们能够再次根据注意力来表达 AFT。具体来说,对于每个位置,我们对每个维度都有一个注意力向量 ai t ∈ RT,由 Q、K、w 组成。换句话说,AFT 可以解释为执行隐式注意力,头部数量与特征维度一样多,其中注意力矩阵采用分解形式。

下略