Transformer 中 Self-Attention 的二次方复杂度(Quadratic Complexity )问题及改进方法:中英双语

时间:2024-12-18 11:12:42

Transformer 中 Self-Attention 的二次方复杂度问题及改进方法

随着大型语言模型(LLM)输入序列长度的增加,Transformer 结构中的核心模块——自注意力机制(Self-Attention) 的计算复杂度和内存消耗都呈现二次方增长。这不仅限制了模型处理长序列的能力,也成为训练和推理阶段的重要瓶颈。

本篇博客将详细解释 Transformer 中 Self-Attention 机制的二次方复杂度来源,结合代码示例展示这一问题,并介绍一些常见的改进方法。


1. Self-Attention 机制简介

原理与公式

在自注意力(Self-Attention)机制中,输入序列 ( X ∈ R n × d X \in \mathbb{R}^{n \times d} XRn×d ) 被映射到三个向量:查询(Query) ( Q Q Q )、键(Key) ( K K K ) 和 值(Value) ( V V V ),三者通过权重矩阵 ( W Q W_Q WQ )、( W K W_K WK )、( W V W_V WV ) 得到:

Q = X W Q , K = X W K , V = X W V Q = X W_Q, \quad K = X W_K, \quad V = X W_V Q=XWQ,K=XWK,V=XWV

自注意力输出的计算公式为:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dk QKT)V

  • ( n n n ) 是输入序列的长度(token 数量)。
  • ( d d d ) 是输入特征的维度。
  • ( d k d_k dk ) 是键向量的维度(通常 ( d k = d / h d_k = d / h dk=d/h ),其中 ( h h h ) 是多头注意力的头数)。

时间复杂度分析

从公式可以看出,自注意力机制中的关键操作是:

  1. ( Q K T Q K^T QKT ):查询向量 ( Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} QRn×dk ) 与键向量 ( K ∈ R n × d k K \in \mathbb{R}^{n \times d_k} KRn×dk ) 相乘,得到 ( n × n n \times n n×n ) 的注意力分数矩阵。

    • 计算复杂度为 ( O ( n 2 d k ) O(n^2 d_k) O(n2dk) )。
  2. softmax 操作:在 ( n × n n \times n n×n ) 的注意力矩阵上进行归一化,复杂度为 ( O ( n 2 ) O(n^2) O(n2) )。

  3. 注意力分数与 ( V V V ) 相乘:将 ( n × n n \times n n×n ) 的注意力分数矩阵与 ( V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} VRn×dv ) 相乘,复杂度为 ( O ( n 2 d v ) O(n^2 d_v) O(n2dv) )。

综上,自注意力机制的时间复杂度为:

O ( n 2 d k + n 2 + n 2 d v ) ≈ O ( n 2 d ) O(n^2 d_k + n^2 + n^2 d_v) \approx O(n^2 d) O(n2dk+n2+n2dv)O(n2d)

  • 当 ( d d d ) 是常数时,复杂度主要取决于输入序列的长度 ( n n n ),即呈二次方增长

空间复杂度分析

自注意力的注意力分数矩阵 ( Q K T Q K^T QKT ) 具有 ( n × n n \times n n×n ) 的大小,需要 ( O ( n 2 ) O(n^2) O(n2) ) 的内存进行存储。


2. 代码示例:计算复杂度与空间消耗

以下代码展示了输入序列长度增加时,自注意力机制的时间和空间消耗情况:

import torch
import time

# 定义自注意力机制
def self_attention(Q, K, V):
    attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(torch.tensor(Q.shape[-1], dtype=torch.float32))
    attention_weights = torch.softmax(attention_scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output

# 测试输入序列长度不同的时间复杂度
def test_attention_complexity():
    d_k = 64  # 特征维度
    for n in [128, 256, 512, 1024, 2048]:  # 输入序列长度
        Q = torch.randn((1, n, d_k))  # Query
        K = torch.randn((1, n, d_k))  # Key
        V = torch.randn((1, n, d_k))  # Value

        start_time = time.time()
        output = self_attention(Q, K, V)
        end_time = time.time()

        print(f"Sequence Length: {n}, Time Taken: {end_time - start_time:.6f} seconds, Output Shape: {output.shape}")

if __name__ == "__main__":
    test_attention_complexity()

运行结果示例

Sequence Length: 128, Time Taken: 0.001200 seconds, Output Shape: torch.Size([1, 128, 64])
Sequence Length: 256, Time Taken: 0.004500 seconds, Output Shape: torch.Size([1, 256, 64])
Sequence Length: 512, Time Taken: 0.015800 seconds, Output Shape: torch.Size([1, 512, 64])
Sequence Length: 1024, Time Taken: 0.065200 seconds, Output Shape: torch.Size([1, 1024, 64])
Sequence Length: 2048, Time Taken: 0.260000 seconds, Output Shape: torch.Size([1, 2048, 64])

从结果可以看出,随着序列长度的增加,计算时间呈现明显的二次方增长。


3. 二次方复杂度的改进方法

为了减少自注意力机制的计算复杂度,许多研究者提出了优化方案,主要包括:

1. 低秩近似方法

利用低秩矩阵分解减少 ( Q K T Q K^T QKT ) 的计算复杂度,例如:

  • Linformer:将 ( n × n n \times n n×n ) 的注意力矩阵通过低秩分解近似为 ( n × k n \times k n×k )(其中 ( k ≪ n k \ll n kn )),复杂度降为 ( O ( n k ) O(nk) O(nk) )。

2. 稀疏注意力(Sparse Attention)

  • LongformerBigBird:通过引入局部窗口和全局注意力机制,仅计算部分注意力分数,避免完整的 ( Q K T Q K^T QKT ) 计算,将复杂度降低为 ( O ( n log ⁡ n ) O(n \log n) O(nlogn) ) 或 ( O ( n ) O(n) O(n) )。

3. 线性注意力(Linear Attention)

  • Performer:使用核技巧将自注意力计算转化为线性操作,复杂度降为 ( O ( n d ) O(n d) O(nd) )。

4. 分块方法(Blockwise Attention)

将输入序列分成多个块,仅在块内或块间进行注意力计算,适用于长序列任务。


4. 总结

在 Transformer 的自注意力机制中,由于需要计算 ( Q K T Q K^T QKT ) 和存储 ( n × n n \times n n×n ) 的注意力矩阵,其时间和空间复杂度均为 (