chatGPT学习---Transformer代码实现2

时间:2022-10-16 01:08:31

下面我们来实现Transformer,在正式编写Transformer之前,我们先来看一下实现Transformer的一个小技巧,这个是我们看懂别人写的Transformer代码的一个关键。

1. 数据技巧

还记得我们在讲Transformer原理时,网络层输出 z \boldsymbol{z} z的计算公式:
z ( j ) = ∑ i = 1 d m c i , j ⋅ v ( i ) \boldsymbol{z}^{(j)} = \sum_{i=1}^{d_{m} } c_{i,j} \cdot \boldsymbol{v}^{(i)} z(j)=i=1dmci,jv(i)
是一个向量相加的运算,虽然我们可以直接求出结果,但是效率会比较低,通常的做法是将其变为一个矩阵的乘法运算。
为了简单起见,我们假设是求向量的平均值,最简单的方法如下所示:

    def startup(self, args={}):
        print('AppExp v0.0.1')
        torch.manual_seed(1337)
        B, T, C = 4, 8, 2 # B: batch_size;T:序列长度;C:通道数,即词汇维度;
        X = torch.randn(B, T, C)
        xbow1 = self.sum1(X, B, T, C)
        print(xbow1)
        xbow2 = self.sum2(X, B, T, C)
        rst = torch.allclose(xbow1, xbow2)
        print(f'比较结果:xbow1==xbow2 => {rst};')
        xbow3 = self.sum3(X, B, T, C)
        rst = torch.allclose(xbow1, xbow3)
        print(f'xbow1和xbow3是否相等?{rst};')

    def sum1(self, X, B, T, C):
        xbow = torch.zeros((B, T, C)) # bag of words
        for b in range(B):
            for t in range(T):
                xprev = X[b, :t+1] # (t, C)
                xbow[b, t] = torch.mean(xprev, 0) # (b, t)
        return xbow
    
    def sum2(self, X, B, T, C):
        wei = torch.tril(torch.ones(T, T)) # Note1
        wei = wei / wei.sum(1, keepdim=True) # Note2
        return wei @ X
    
    def sum3(self, X, B, T, C):
        tril = torch.tril(torch.ones(T, T))
        wei = torch.zeros((T, T))
        wei = wei.masked_fill(tril==0, float('-inf')) # Note3
        wei = F.softmax(wei, dim=-1) # Note4
        return wei @ X

以上三个方法的结果是相同的,但是计算效率一个比一个高,我们会将这个技巧用到Transformer的实现中。

  • Note1:torch.tril为一个下三角矩阵:
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
  • Note2:每个元素除以它所在行的和,如下所示:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
  • Note3:是两个矩阵相乘,wei的形状为(8, 8),X的形状为(4, 8, 2),根据张量乘法,wei的(8, 8)与X(8, 2)作传统意义上的矩阵乘法运算,形成一个新的(8, 2),最后再叠加成(4, 8, 2),我们以其中一个为例:
    [ 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.5000 0.5000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.3333 0.3333 0.3333 0.0000 0.0000 0.0000 0.0000 0.0000 0.2500 0.2500 0.2500 0.2500 0.0000 0.0000 0.0000 0.0000 0.2000 0.2000 0.2000 0.2000 0.2000 0.0000 0.0000 0.0000 0.1667 0.1667 0.1667 0.1667 0.1667 0.1667 0.0000 0.0000 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.0000 0.1250 0.1250 0.1250 0.1250 0.1250 0.1250 0.1250 0.1250 ] ⋅ [ [ w o r d 1 , 1 w o r d 1 , 2 ] [ w o r d 2 , 1 w o r d 2 , 2 ] [ w o r d 3 , 1 w o r d 3 , 2 ] [ w o r d 4 , 1 w o r d 4 , 2 ] [ w o r d 5 , 1 w o r d 5 , 2 ] [ w o r d 6 , 1 w o r d 6 , 2 ] [ w o r d 7 , 1 w o r d 7 , 2 ] [ w o r d 8 , 1 w o r d 8 , 2 ] ] = [ [ w o r d 1 , 1 w o r d 1 , 2 ] [ ( w o r d 1 , 1 + w o r d 2 , 1 ) ∗ 0.5 ( w o r d 1 , 2 + w o r d 2 , 2 ) ∗ 0.5 ] [ ( w o r d 1 , 1 + w o r d 2 , 1 + w o r d 3 , 1 ) ∗ 0.3333 ( w o r d 1 , 2 + w o r d 2 , 2 + w o r d 3 , 2 ) ∗ 0.3333 ] . . . . . . ] \begin{bmatrix} 1.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ 0.5000 & 0.5000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ 0.3333 & 0.3333 & 0.3333 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ 0.2500 & 0.2500 & 0.2500 & 0.2500 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ 0.2000 & 0.2000 & 0.2000 & 0.2000 & 0.2000 & 0.0000 & 0.0000 & 0.0000 \\ 0.1667 & 0.1667 & 0.1667 & 0.1667 & 0.1667 & 0.1667 & 0.0000 & 0.0000 \\ 0.1429 & 0.1429 & 0.1429 & 0.1429 & 0.1429 & 0.1429 & 0.1429 & 0.0000 \\ 0.1250 & 0.1250 & 0.1250 & 0.1250 & 0.1250 & 0.1250 & 0.1250 & 0.1250 \end{bmatrix} \cdot \begin{bmatrix} \begin{bmatrix} word_{1,1} & word_{1,2} \end{bmatrix} \\ \begin{bmatrix} word_{2,1} & word_{2,2} \end{bmatrix} \\ \begin{bmatrix} word_{3,1} & word_{3,2} \end{bmatrix} \\ \begin{bmatrix} word_{4,1} & word_{4,2} \end{bmatrix} \\ \begin{bmatrix} word_{5,1} & word_{5,2} \end{bmatrix} \\ \begin{bmatrix} word_{6,1} & word_{6,2} \end{bmatrix} \\ \begin{bmatrix} word_{7,1} & word_{7,2} \end{bmatrix} \\ \begin{bmatrix} word_{8,1} & word_{8,2} \end{bmatrix} \\ \end{bmatrix} = \\ \begin{bmatrix} \begin{bmatrix} word_{1,1} & word_{1,2} \end{bmatrix} \\ \begin{bmatrix} (word_{1,1} + word_{2,1})*0.5 & (word_{1,2} + word_{2,2})*0.5 \end{bmatrix} \\ \begin{bmatrix} (word_{1,1} + word_{2,1} + word_{3,1})*0.3333 & (word_{1,2} + word_{2,2} + word_{3,2})*0.3333 \end{bmatrix} \\ ...... \end{bmatrix} 1.00000.50000.33330.25000.20000.16670.14290.12500.00000.50000.33330.25000.20000.16670.14290.12500.00000.00000.33330.25000.20000.16670.14290.12500.00000.00000.00000.25000.20000.16670.14290.12500.00000.00000.00000.00000.20000.16670.14290.12500.00000.00000.00000.00000.00000.16670.14290.12500.00000.00000.00000.00000.00000.00000.14290.12500.00000.00000.00000.00000.00000.00000.00000.1250 [word1,1word1,2][word2,1word2,2][word3,1word3,2][word4,1word4,2][word5,1word5,2][word6,1word6,2][word7,1word7,2][word8,1word8,2] = [word1,1word1,2][(word1,1+word2,1)0.5(word1,2+word2,2)0.5][(word1,1+word2,1+word3,1)0.3333(word1,2+word2,2+word3,2)0.3333]......
  • Note3:
tensor([[1., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., -inf],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
  • Note4:我们知道 e − i n f = 0 e^{-inf}=0 einf=0,其余项均相同,所以求softmax后,得到与Note2处相同的矩阵:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

我们做这个操作的意义在于,我们当前位置的单词,假设他只跟它前面的单词有关,我们认为其关系就是所有前面的单词(包括其自身)相加然后再取平均。这样我们每个单词就不仅表示自己的信息,同时也包括了它之前所有单词的信息,从而有效的解决了前文中BigramLanguageModel中,只能看前面一个单词的限制条件。
为了后面做Transformer,我们对BigramLanguageModel类做了一些修改,完整代码请参考:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.0.6

2. 自注意力机制

    def startup(self, args={}):
        print('AppExp v0.0.1')
        torch.manual_seed(1337)
        B, T, C = 4, 8, AppRegistry.n_embed # B: batch_size;T:序列长度;C:通道数,即词汇维度;
        X = torch.randn(B, T, C)
        self.self_attention(X, B, T, C)

    def self_attention(self, X, B, T, C):
        W_K = nn.Linear(C, AppRegistry.head_size, bias=False)
        W_Q = nn.Linear(C, AppRegistry.head_size, bias=False)
        W_V = nn.Linear(C, AppRegistry.head_size, bias=False)
        k = W_K(X) # (B, T, h) # Note1
        q = W_Q(X) # (B, T, h) # Note2
        wei = q @ k.transpose(-2, -1) / (AppRegistry.head_size**0.5) # (B, T, h) @ (B, h, T) => (B, T, T) # Note3
        tril = torch.tril(torch.ones(T, T))
        wei = wei.masked_fill(tril==0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        v = W_V(X)
        out = wei @ v
        print(f'out: {out.shape};')

根据上一节的内容,对于输入信号X,我们定义了三个权值矩阵: W K W^{K} WK W Q W^{Q} WQ W V W^{V} WV,其中K代表经过索引后便于查询的索引,Q代表要查找的内容,为了便于理解,我们以一个输入信号为例:
k = W K x ( 1 ) q = W Q x ( 2 ) \boldsymbol{k} = W^{K} \boldsymbol{x} \quad \quad \quad (1)\\ \boldsymbol{q} = W^{Q}\boldsymbol{x} \quad \quad \quad (2) k=WKx(1)q=WQx(2)
第i个单词查询对第j个单词的关联度为:
w e i = q ( i ) ⋅ k ( j ) d k wei = \frac{ \boldsymbol{q}^{(i)} \cdot \boldsymbol{k}^{(j)} } { \sqrt{ d_{k} } } wei=dk q(i)k(j)
其中 d k = 16 d_{k}=16 dk=16,为自注意力头的维度,除以其平方根的目的是为了求softmax时的值变得更平均一些。

3. 自注意力头

class Head(nn.Module):
    def __init__(self, block_size=8, n_embed=32, head_size=16):
        super(Head, self).__init__()
        self.W_K = nn.Linear(n_embed, head_size, bias = False)
        self.W_Q = nn.Linear(n_embed, head_size, bias = False)
        self.W_V = nn.Linear(n_embed, head_size, bias = False)
        self.register_buffer('tril', torch.tril(torch.ones(head_size, head_size)))

    def forward(self, X):
        B, T, C = X.shape
        k = self.W_K(X) # (B, T, h)
        q = self.W_Q(X) # (B, T, h)
        wei = (q @ k.transpose(-2, -1) / (AppRegistry.head_size**0.5)).to(AppRegistry.device) # (B, T, h) @ (B, h, T) => (B, T, T)
        tril = torch.tril(torch.ones(T, T)).to(AppRegistry.device)
        wei = wei.masked_fill(tril==0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        v = self.W_V(X)
        return wei @ v

向模型中添加自注意力头:

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 词汇数,单词维度
        self.token_embedding_table = nn.Embedding(AppRegistry.vocab_size, AppRegistry.n_embed)
        self.position_embedding_table = nn.Embedding(AppRegistry.block_size, AppRegistry.n_embed)
        self.sa_head = Head(head_size=AppRegistry.n_embed, n_embed=AppRegistry.n_embed, block_size=AppRegistry.block_size)
        self.lm_head = nn.Linear(AppRegistry.n_embed, AppRegistry.vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C) C=n_embed
        pos_emb = self.position_embedding_table(torch.arange(T, device=AppRegistry.device))
        x = tok_emb + pos_emb # (B, T, C)
        x = self.sa_head(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        return logits
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -AppRegistry.block_size:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] # (B, T, C) => (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

为了验证程序的正确性,可以运行一下训练过程。完整代码请参考:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.0.7

4. 多头机制

添加多头模型支持:

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super(MultiHeadAttention, self).__init__()
        self.heads = nn.ModuleList([Head(AppRegistry.block_size, AppRegistry.n_embed, head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(AppRegistry.n_embed, AppRegistry.n_embed)

    def forward(self, X):
        X = torch.cat([h(X) for h in self.heads], dim=-1)
        return self.proj(X)

向BigramLanguageModel中添加支持多头模型:

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 词汇数,单词维度
        self.token_embedding_table = nn.Embedding(AppRegistry.vocab_size, AppRegistry.n_embed)
        self.position_embedding_table = nn.Embedding(AppRegistry.block_size, AppRegistry.n_embed)
        self.sa_heads = MultiHeadAttention(num_heads=4, head_size=AppRegistry.n_embed//4)
        self.lm_head = nn.Linear(AppRegistry.n_embed, AppRegistry.vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C) C=n_embed
        pos_emb = self.position_embedding_table(torch.arange(T, device=AppRegistry.device))
        x = tok_emb + pos_emb # (B, T, C)
        x = self.sa_heads(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        return logits

完整代码请参考:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.0.8

5. 添加前向传播网络

class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
        )

    def forward(self, X):
        return self.net(X)

使用前向传播网络:

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 词汇数,单词维度
        self.token_embedding_table = nn.Embedding(AppRegistry.vocab_size, AppRegistry.n_embed)
        self.position_embedding_table = nn.Embedding(AppRegistry.block_size, AppRegistry.n_embed)
        self.sa_heads = MultiHeadAttention(num_heads=4, head_size=AppRegistry.n_embed//4)
        self.ffwd = FeedForward(AppRegistry.n_embed)
        self.lm_head = nn.Linear(AppRegistry.n_embed, AppRegistry.vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C) C=n_embed
        pos_emb = self.position_embedding_table(torch.arange(T, device=AppRegistry.device))
        x = tok_emb + pos_emb # (B, T, C)
        x = self.sa_heads(x)
        x = self.ffwd(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        return logits

完整代码请参考:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.0.9

6. 添加Block

我们下面来添加Encoder的Block:

class TransformerEncoderBlock(nn.Module):
    def __init__(self, n_embed, n_head):
        super(TransformerEncoderBlock, self).__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embed)

    def forward(self, X):
        X = self.sa(X)
        X = self.ffwd(X)
        return X

使用Block:

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 词汇数,单词维度
        self.token_embedding_table = nn.Embedding(AppRegistry.vocab_size, AppRegistry.n_embed)
        self.position_embedding_table = nn.Embedding(AppRegistry.block_size, AppRegistry.n_embed)
        self.blocks = nn.Sequential(
            TransformerEncoderBlock(n_embed=AppRegistry.n_embed, n_head=4),
            TransformerEncoderBlock(n_embed=AppRegistry.n_embed, n_head=4),
            TransformerEncoderBlock(n_embed=AppRegistry.n_embed, n_head=4),
        )
        self.lm_head = nn.Linear(AppRegistry.n_embed, AppRegistry.vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C) C=n_embed
        pos_emb = self.position_embedding_table(torch.arange(T, device=AppRegistry.device))
        x = tok_emb + pos_emb # (B, T, C)
        x = self.blocks(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        return logits

完整代码请参考:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.0.10

7. 添加Residue连接和dropout

我们添加Residue连接和dropout,并且调整超参数的值,就可以得到最终版本,由于调整的地方比较多,而且比较杂,具体修改请参考:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.1.0