手动实现一个迷你Llama:手动实现Llama模型

时间:2025-04-06 22:31:34

  • 进阶的 LLM Llama模型教学
    • 一、库导入
    • 二、实现 ModelArgs 参数类构建
      • Transformer 模型参数解释
    • 三、实现均方根归一化(RMSNorm,LayerNorm 的一种变体)层
      • 定义与原理
      • RMSNorm 公式
      • 与 LayerNorm 的对比
      • RMSNorm 的优点
      • RMSNorm 的实现
      • RMSNorm 的关键步骤
        • 参数说明
        • `__init__` 关键操作
        • `_norm` 关键操作
        • `forward` 关键操作
      • 举一个张量作为 RMSNorm 的例子
        • 示例张量
        • RMSNorm 的计算过程
          • 1. 初始化 RMSNorm
          • 2. 计算均方根值(RMS)
          • 3. 归一化
          • 4. 应用缩放因子
    • 四、旋转位置矩阵函数实现
      • 旋转位置嵌入的基本概念
      • 代码解析
        • 1. 计算频率 freqs \text{freqs} freqs
        • 2. 生成时间序列 t t t
        • 3. 计算外积
        • 4. 计算实部和虚部
      • 输出
      • 位置编码中需要注意的点
        • 1. **嵌入向量的拆分**
        • 2. **旋转操作**
      • 3. **应用旋转**
      • 4. **合并旋转后的结果**
      • 具体的例子
      • 步骤1:计算频率
      • 步骤2:生成时间序列
      • 步骤3:计算外积
      • 步骤4:计算实部和虚部
      • 最终结果
    • 五、旋转位置嵌入
      • 传统位置编码 vs. 旋转位置嵌入
        • 传统位置编码
        • 旋转位置嵌入
      • 在自注意力机制中的应用
      • 为什么在自注意力中而不是嵌入层中
      • ROPE函数的实现
      • 1. `reshape_for_broadcast` 函数
      • 广播操作示例
        • 1.1解析
        • 1.2作用
      • 2. `apply_rotary_emb` 函数
        • 示例输入
        • 示例计算
        • 步骤2:调整频率张量的形状以进行广播
        • 2.示例计算
        • 步骤3:应用旋转操作
        • 3.示例计算
        • 步骤4:将最后两个维度合并,并还原为原始张量的形状
        • 4.示例计算
        • 4.最终结果
        • 总结
        • 2.1解析
        • 2.2作用
      • 3. `repeat_kv` 函数
        • 3.1解析
        • 3.2作用
      • 3.3总结
    • 六、attention 模块
      • 前向传播过程
      • **输入和参数回顾**
      • **前向传播过程**
        • **1. 计算查询(Q)、键(K)、值(V)**
        • **2. 调整形状**
        • **3. 应用旋转位置嵌入(RoPE)**
        • **4. 调整维度**
        • **5. 注意力计算**
        • 6.**添加因果遮蔽矩阵**
        • **应用 Softmax**
        • **添加因果遮蔽矩阵并应用 Softmax**
        • **应用 Dropout**
        • **计算加权和**
      • **6. 恢复维度并投影**
        • **恢复时间维度并合并头**
        • **最终投影回残差流**
      • **最终输出**
      • **总结**
    • 七、什么是时间维度
      • **时间维度的含义**
      • **时间维度在代码中的体现**
        • **1. 输入矩阵的形状**
        • **2. 时间维度的调整**
        • **3. 时间维度的恢复**
        • **4. 因果遮蔽**
      • 6.总结
    • 八、MLP(多层感知机,Multilayer Perceptron)
      • MLP 结构
      • 激活函数
    • DecoderLayer解码器层的实现
    • 九、transformer总架构

进阶的 LLM Llama模型教学

Llama 模型在自然语言处理领域有着广泛的应用,它通过自注意力机制能够有效地捕捉序列中的长距离依赖关系。为了更好地理解和实现这个模型,我们先从一些基础的代码和概念入手。

一、库导入

在开始之前,我们需要导入一些必要的 Python 库。这些库将帮助我们完成模型的构建和训练。

import math
import struct
import inspect
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn

这些库涵盖了数学运算、数据结构定义、类型提示以及 PyTorch 框架的相关功能,为后续的模型实现提供了强大的支持。

二、实现 ModelArgs 参数类构建

接下来,我们来定义一个参数类 ModelArgs,用于存储 Transformer 模型的各种超参数。这些超参数将决定模型的结构和行为。

@dataclass
class ModelArgs:
    # 自定义超参数
    dim: int = 288  # 模型维度
    n_layers: int = 6  # Transformer层数
    n_heads: int = 6  # 注意力机制的头数
    n_kv_heads: Optional[int] = 6  # 键/值头数,如果未指定,则默认为n_heads
    vocab_size: int = 32000  # 词汇表大小
    hidden_dim: Optional[int] = None  # 隐藏层维度,如果未指定,则使用其他规则确定
    multiple_of: int = 32  # MLP隐藏层大小是这个数的倍数
    norm_eps: float = 1e-5  # 归一化层的epsilon值
    max_seq_len: int = 256  # 最大序列长度
    dropout: float = 0.0  # 丢弃率

Transformer 模型参数解释

  • dim(模型的嵌入维度):这是每个输入词或序列元素的特征维度。它决定了模型对输入数据的表示能力。
  • n_heads(多头注意力机制中的头数):这个参数决定了嵌入维度如何被拆分以及进行并行计算。多头注意力机制能够让模型从不同的角度学习输入数据的特征。
  • n_layers(Transformer 的层数):即模型中包含的 Transformer 编码器或解码器的数量。层数越多,模型能够捕捉到的复杂关系就越多,但计算成本也会相应增加。
  • n_kv_heads(键(Key)和值(Value)的头数):在某些模型(如 LLaMA)中,键和值的头数可以与查询(Query)的头数不同,以减少计算量。这个参数提供了灵活性,使模型能够在保持性能的同时降低计算成本。
  • vocab_size(词汇表的大小):即模型可以处理的不同词或标记的数量。它决定了模型的输入范围。
  • hidden_dim(MLP 隐藏层的维度):这是多层感知机(MLP)隐藏层的维度。如果未指定,则会根据其他规则(如模型维度的倍数)动态计算。MLP 是 Transformer 中的一个重要组件,用于对输入数据进行非线性变换。
  • multiple_of(MLP 隐藏层大小的倍数):MLP 隐藏层大小必须是这个数的倍数。这通常是出于硬件优化的考虑,例如在 GPU 上进行矩阵运算时,某些维度大小为 32 的倍数可以提高计算效率。
  • norm_eps(归一化层的 epsilon 值):这是归一化层(如 LayerNorm)中的一个小常数,用于防止除零操作。在计算归一化时,它能够确保数值稳定性。
  • max_seq_len(最大序列长度):即输入序列的最大长度。这个参数限制了模型能够处理的序列长度,对于长文本处理非常重要。
  • dropout(丢弃率):这是在训练过程中,模型中某些层的输出被随机丢弃的比例。丢弃率可以防止过拟合,并提高模型的泛化能力。

三、实现均方根归一化(RMSNorm,LayerNorm 的一种变体)层

定义与原理

RMSNorm 是 LayerNorm 的一种变体,它通过计算输入向量的均方根(Root Mean Square, RMS)来进行归一化,而省略了计算均值的步骤。这种方法在某些情况下能够提高计算效率和数值稳定性。

RMSNorm 公式

对于输入向量 x = [ x 1 , x 2 , … , x H ] \mathbf{x} = [x_1, x_2, \dots, x_H] x=[x1,x2,,xH],RMSNorm 的计算步骤如下:

  1. 计算均方根值(RMS)

RMS ( x ) = 1 H ∑ i = 1 H x i 2 \text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{H} \sum_{i=1}^{H} x_i^2} RMS(x)=H1i=1Hxi2
其中, H H H是输入向量的维度。
2. 归一化

x ^ i = x i RMS ( x ) + ϵ \hat{x}_i = \frac{x_i}{\text{RMS}(\mathbf{x}) + \epsilon} x^i=RMS(x)+ϵxi
其中, ϵ \epsilon ϵ是一个极小的常数(如 1 0 − 8 10^{-8} 108),用于防止分母为零。
3. 缩放(可选)

y i = γ i ⋅ x ^ i y_i = \gamma_i \cdot \hat{x}_i yi=γix^i
其中, γ \gamma γ是可学习的缩放参数,与输入向量同维度。

将上述步骤综合起来,RMSNorm 的完整公式为:
RMSNorm ( x ) = γ ⊙ x 1 H ∑ i = 1 H x i 2 + ϵ \text{RMSNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x}}{\sqrt{\frac{1}{H} \sum_{i=1}^{H} x_i^2 + \epsilon}} RMSNorm(x)=γH1i=1Hxi2+ϵ x
其中, ⊙ \odot 表示逐元素乘法。

与 LayerNorm 的对比

为了更好地理解 RMSNorm,我们来看一下它与 LayerNorm 的区别:

  • LayerNorm 公式

x ^ i = x i − μ σ 2 + ϵ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} x^i=σ2+ϵ xiμ
其中, μ \mu μ是均值, σ 2 \sigma^2 σ2是方差。

  • RMSNorm 公式

x ^ i = x i 1 H ∑ i = 1 H x i 2 + ϵ \hat{x}_i = \frac{x_i}{\sqrt{\frac{1}{H} \sum_{i=1}^{H} x_i^2 + \epsilon}} x^i=H1i=1Hxi2+ϵ xi
RMSNorm 省略了均值的计算,仅使用均方根值进行归一化。

RMSNorm 的优点

RMSNorm 与 LayerNorm 相比,具有以下优势:

  1. 计算效率更高:RMSNorm 省略了计算均值的步骤,仅需计算平方均值,减少了约 15% 的计算量。
  2. 数值稳定性更好:由于不涉及均值计算,RMSNorm 在某些情况下可以避免均值归一化导致的梯度消失问题。
  3. 适用于 Transformer 架构:在 Transformer 等对计算效率敏感的场景中,RMSNorm 可以显著加速训练。

RMSNorm 的实现

接下来,我们来看看如何用 Python 实现 RMSNorm。

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        # eps 是为了防止除以 0 的情况
        self.eps = eps
        # weight 是一个可学习的参数,全部初始化为 1
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # 计算 RMSNorm 的核心部分
        # x.pow(2).mean(-1, keepdim=True) 计算了输入 x 的平方的均值
        # torch.rsqrt 是平方根的倒数,这样就得到了 RMSNorm 的分母部分,再加上 eps 防止分母为 0
        # 最后乘以 x,得到 RMSNorm 的结果
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # forward 函数是模型的前向传播
        # 首先将输入 x 转为 float 类型,然后进行 RMSNorm,最后再转回原来的数据类型
        # 最后乘以 weight,这是 RMSNorm 的一个可学习的缩放因子
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

RMSNorm 的关键步骤

参数说明
  • dim:输入数据的特征维度。例如,如果输入数据的形状是 (batch_size, sequence_length, dim),则 dim 是最后一个维度的大小。
  • eps:一个非常小的数值,用于防止分母为零,确保数值稳定性。
__init__ 关键操作
  • self.eps:存储 eps 值,用于后续的归一化计算。
  • self.weight:定义一个可学习的参数 weight,其初始值为全1。这个参数在归一化后对输出进行缩放。
_norm 关键操作
  1. 计算平方的均值

    x.pow(2).mean(-1, keepdim=True)
    
    • x.pow(2):计算输入张量 x 的每个元素的平方。
    • .mean(-1, keepdim=True):沿着最后一个维度(特征维度)计算均值,并保持输出的维度与输入相同。
  2. 计算平方根的倒数

    torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    
    • torch.rsqrt:计算平方根的倒数,即 1 value \frac{1}{\sqrt{\text{value}}} value