Transformer 中 Self-Attention 的二次方复杂度问题及改进方法
随着大型语言模型(LLM)输入序列长度的增加,Transformer 结构中的核心模块——自注意力机制(Self-Attention) 的计算复杂度和内存消耗都呈现二次方增长。这不仅限制了模型处理长序列的能力,也成为训练和推理阶段的重要瓶颈。
本篇博客将详细解释 Transformer 中 Self-Attention 机制的二次方复杂度来源,结合代码示例展示这一问题,并介绍一些常见的改进方法。
1. Self-Attention 机制简介
原理与公式
在自注意力(Self-Attention)机制中,输入序列 ( X ∈ R n × d X \in \mathbb{R}^{n \times d} X∈Rn×d ) 被映射到三个向量:查询(Query) ( Q Q Q )、键(Key) ( K K K ) 和 值(Value) ( V V V ),三者通过权重矩阵 ( W Q W_Q WQ )、( W K W_K WK )、( W V W_V WV ) 得到:
Q = X W Q , K = X W K , V = X W V Q = X W_Q, \quad K = X W_K, \quad V = X W_V Q=XWQ,K=XWK,V=XWV
自注意力输出的计算公式为:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dkQKT)V
- ( n n n ) 是输入序列的长度(token 数量)。
- ( d d d ) 是输入特征的维度。
- ( d k d_k dk ) 是键向量的维度(通常 ( d k = d / h d_k = d / h dk=d/h ),其中 ( h h h ) 是多头注意力的头数)。
时间复杂度分析
从公式可以看出,自注意力机制中的关键操作是:
-
( Q K T Q K^T QKT ):查询向量 ( Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} Q∈Rn×dk ) 与键向量 ( K ∈ R n × d k K \in \mathbb{R}^{n \times d_k} K∈Rn×dk ) 相乘,得到 ( n × n n \times n n×n ) 的注意力分数矩阵。
- 计算复杂度为 ( O ( n 2 d k ) O(n^2 d_k) O(n2dk) )。
-
softmax 操作:在 ( n × n n \times n n×n ) 的注意力矩阵上进行归一化,复杂度为 ( O ( n 2 ) O(n^2) O(n2) )。
-
注意力分数与 ( V V V ) 相乘:将 ( n × n n \times n n×n ) 的注意力分数矩阵与 ( V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} V∈Rn×dv ) 相乘,复杂度为 ( O ( n 2 d v ) O(n^2 d_v) O(n2dv) )。
综上,自注意力机制的时间复杂度为:
O ( n 2 d k + n 2 + n 2 d v ) ≈ O ( n 2 d ) O(n^2 d_k + n^2 + n^2 d_v) \approx O(n^2 d) O(n2dk+n2+n2dv)≈O(n2d)
- 当 ( d d d ) 是常数时,复杂度主要取决于输入序列的长度 ( n n n ),即呈二次方增长。
空间复杂度分析
自注意力的注意力分数矩阵 ( Q K T Q K^T QKT ) 具有 ( n × n n \times n n×n ) 的大小,需要 ( O ( n 2 ) O(n^2) O(n2) ) 的内存进行存储。
2. 代码示例:计算复杂度与空间消耗
以下代码展示了输入序列长度增加时,自注意力机制的时间和空间消耗情况:
import torch
import time
# 定义自注意力机制
def self_attention(Q, K, V):
attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(torch.tensor(Q.shape[-1], dtype=torch.float32))
attention_weights = torch.softmax(attention_scores, dim=-1)
output = torch.matmul(attention_weights, V)
return output
# 测试输入序列长度不同的时间复杂度
def test_attention_complexity():
d_k = 64 # 特征维度
for n in [128, 256, 512, 1024, 2048]: # 输入序列长度
Q = torch.randn((1, n, d_k)) # Query
K = torch.randn((1, n, d_k)) # Key
V = torch.randn((1, n, d_k)) # Value
start_time = time.time()
output = self_attention(Q, K, V)
end_time = time.time()
print(f"Sequence Length: {n}, Time Taken: {end_time - start_time:.6f} seconds, Output Shape: {output.shape}")
if __name__ == "__main__":
test_attention_complexity()
运行结果示例
Sequence Length: 128, Time Taken: 0.001200 seconds, Output Shape: torch.Size([1, 128, 64])
Sequence Length: 256, Time Taken: 0.004500 seconds, Output Shape: torch.Size([1, 256, 64])
Sequence Length: 512, Time Taken: 0.015800 seconds, Output Shape: torch.Size([1, 512, 64])
Sequence Length: 1024, Time Taken: 0.065200 seconds, Output Shape: torch.Size([1, 1024, 64])
Sequence Length: 2048, Time Taken: 0.260000 seconds, Output Shape: torch.Size([1, 2048, 64])
从结果可以看出,随着序列长度的增加,计算时间呈现明显的二次方增长。
3. 二次方复杂度的改进方法
为了减少自注意力机制的计算复杂度,许多研究者提出了优化方案,主要包括:
1. 低秩近似方法
利用低秩矩阵分解减少 ( Q K T Q K^T QKT ) 的计算复杂度,例如:
- Linformer:将 ( n × n n \times n n×n ) 的注意力矩阵通过低秩分解近似为 ( n × k n \times k n×k )(其中 ( k ≪ n k \ll n k≪n )),复杂度降为 ( O ( n k ) O(nk) O(nk) )。
2. 稀疏注意力(Sparse Attention)
- Longformer 和 BigBird:通过引入局部窗口和全局注意力机制,仅计算部分注意力分数,避免完整的 ( Q K T Q K^T QKT ) 计算,将复杂度降低为 ( O ( n log n ) O(n \log n) O(nlogn) ) 或 ( O ( n ) O(n) O(n) )。
3. 线性注意力(Linear Attention)
- Performer:使用核技巧将自注意力计算转化为线性操作,复杂度降为 ( O ( n d ) O(n d) O(nd) )。
4. 分块方法(Blockwise Attention)
将输入序列分成多个块,仅在块内或块间进行注意力计算,适用于长序列任务。
4. 总结
在 Transformer 的自注意力机制中,由于需要计算 ( Q K T Q K^T QKT ) 和存储 ( n × n n \times n n×n ) 的注意力矩阵,其时间和空间复杂度均为 (