[WIP] GQA MSA, einops version

Time: 2025-04-12 06:58:35
Grouped-query attention (GQA): `MyGQA` does the head bookkeeping with einops `rearrange`, `MyGQA2` does the same thing with plain `reshape`/`transpose`, and the `__main__` block cross-checks that the two produce identical outputs.

```python
import math

import torch
import torch.nn as nn
from einops import rearrange


class MyGQA(nn.Module):
    """Grouped-query attention implemented with einops."""

    def __init__(self, nheads, dim, ngroups):
        super().__init__()
        self.head_dim = dim // nheads
        self.nheads = nheads
        self.dim = dim
        self.ngroups = ngroups
        self.heads_per_group = nheads // ngroups
        # dim = self.head_dim * nheads
        # dim = self.head_dim * self.heads_per_group * ngroups
        self.q_proj = nn.Linear(dim, dim)
        # K/V only get one head_dim slice per group, i.e. dim // heads_per_group features.
        self.k_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.v_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.o_proj = nn.Linear(dim, dim)
        self.ln = nn.LayerNorm(dim)  # defined but not used in forward

    def forward(self, query, key, value, attn_mask=None):
        bs, q_len, dim = query.shape
        # Equivalent reshape/transpose formulation (same as MyGQA2 below):
        # q = self.q_proj(query).reshape(bs, q_len, self.nheads, self.head_dim) \
        #     .transpose(1, 2).reshape(bs, self.nheads, q_len, self.head_dim)
        # k = self.k_proj(key).repeat_interleave(self.heads_per_group, dim=0) \
        #     .reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim) \
        #     .transpose(2, 3).reshape(bs, self.nheads, q_len, self.head_dim)
        # v = self.v_proj(value).repeat_interleave(self.heads_per_group, dim=0) \
        #     .reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim) \
        #     .transpose(2, 3).reshape(bs, self.nheads, q_len, self.head_dim)
        q = rearrange(self.q_proj(query), 'b l (head k) -> b head l k', head=self.nheads)
        # Repeat each batch element heads_per_group times, then fold the copies into the
        # head axis, so each K/V group is shared by heads_per_group query heads.
        k = rearrange(self.k_proj(key).repeat_interleave(self.heads_per_group, dim=0),
                      '(b heads_per_group) l (ngroups k) -> b (heads_per_group ngroups) l k',
                      heads_per_group=self.heads_per_group, ngroups=self.ngroups)
        v = rearrange(self.v_proj(value).repeat_interleave(self.heads_per_group, dim=0),
                      '(b heads_per_group) l (ngroups k) -> b (heads_per_group ngroups) l k',
                      heads_per_group=self.heads_per_group, ngroups=self.ngroups)
        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        output = torch.matmul(attn, v)  # bs, nheads, q_len, head_dim
        output = self.o_proj(rearrange(output, 'b head l k -> b l (head k)'))
        # output = self.o_proj(output.transpose(1, 2).reshape(bs, q_len, self.nheads * self.head_dim))
        return output, attn


class MyGQA2(nn.Module):
    """The same GQA written with plain reshape/transpose, for comparison."""

    def __init__(self, nheads, dim, ngroups):
        super().__init__()
        self.head_dim = dim // nheads
        self.nheads = nheads
        self.dim = dim
        self.ngroups = ngroups
        self.heads_per_group = nheads // ngroups
        # dim = self.head_dim * nheads
        # dim = self.head_dim * self.heads_per_group * ngroups
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.v_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, query, key, value, attn_mask=None):
        bs, q_len, dim = query.shape
        q = (self.q_proj(query).reshape(bs, q_len, self.nheads, self.head_dim)
             .transpose(1, 2).reshape(bs, self.nheads, q_len, self.head_dim))
        k = (self.k_proj(key).repeat_interleave(self.heads_per_group, dim=0)
             .reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim)
             .transpose(2, 3).reshape(bs, self.nheads, q_len, self.head_dim))
        v = (self.v_proj(value).repeat_interleave(self.heads_per_group, dim=0)
             .reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim)
             .transpose(2, 3).reshape(bs, self.nheads, q_len, self.head_dim))
        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        output = torch.matmul(attn, v)  # bs, nheads, q_len, head_dim
        output = self.o_proj(output.transpose(1, 2).reshape(bs, q_len, self.nheads * self.head_dim))
        return output, attn


if __name__ == '__main__':
    embed_dim, num_heads, num_groups = 256, 8, 4
    q_len, bs = 2, 3
    query = torch.randn(bs, q_len, embed_dim)
    key = torch.randn(bs, q_len, embed_dim)
    value = torch.randn(bs, q_len, embed_dim)

    my_multihead_attn = MyGQA(num_heads, embed_dim, num_groups)
    for param in my_multihead_attn.parameters():
        param.data.fill_(0.1)
    my_attn_output, my_attn_output_weights = my_multihead_attn(query, key, value)
    print('my_attn_output={}'.format(my_attn_output))

    my_multihead_attn2 = MyGQA2(num_heads, embed_dim, num_groups)
    for param in my_multihead_attn2.parameters():
        param.data.fill_(0.1)
    my_attn_output2, my_attn_output_weights2 = my_multihead_attn2(query, key, value)
    print('my_attn_output2={}'.format(my_attn_output2))

    max_diff = torch.max(torch.abs(my_attn_output - my_attn_output2)).item()
    print(torch.equal(my_attn_output_weights, my_attn_output_weights2))
    print('max_diff={}'.format(max_diff))
```
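The batch-level `repeat_interleave` plus `rearrange` used for K/V in `MyGQA` can also be collapsed into a single `einops.repeat` call that expands the group axis directly into heads. A minimal standalone sketch (made-up sizes, not part of the original code) showing the two routes give the same tensor:

```python
import torch
from einops import rearrange, repeat

bs, q_len, ngroups, heads_per_group, head_dim = 3, 2, 4, 2, 32

# K as it comes out of k_proj: one head_dim slice per group.
k_grouped = torch.randn(bs, q_len, ngroups * head_dim)

# Original route: duplicate the batch, then fold the copies into the head axis.
k_a = rearrange(k_grouped.repeat_interleave(heads_per_group, dim=0),
                '(b heads_per_group) l (ngroups k) -> b (heads_per_group ngroups) l k',
                heads_per_group=heads_per_group, ngroups=ngroups)

# Single-call alternative: let einops.repeat create the heads_per_group axis directly.
k_b = repeat(k_grouped,
             'b l (ngroups k) -> b (heads_per_group ngroups) l k',
             ngroups=ngroups, heads_per_group=heads_per_group)

print(torch.equal(k_a, k_b))  # True
```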
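The test above never exercises `attn_mask`. A small usage sketch (assuming `MyGQA` from above is in scope): positions holding 0 in the mask are the ones `masked_fill(attn_mask == 0, -inf)` blocks, so a lower-triangular mask gives causal attention.

```python
import torch

# Hypothetical causal-mask usage for MyGQA (not in the original test).
embed_dim, num_heads, num_groups = 256, 8, 4
bs, q_len = 3, 5

x = torch.randn(bs, q_len, embed_dim)
# (q_len, q_len) broadcasts over batch and heads; zeros mark blocked positions.
causal_mask = torch.tril(torch.ones(q_len, q_len, dtype=torch.bool))

gqa = MyGQA(num_heads, embed_dim, num_groups)
out, attn = gqa(x, x, x, attn_mask=causal_mask)

# After softmax the masked (upper-triangular) attention weights should be zero.
print(torch.allclose(attn.triu(1), torch.zeros_like(attn)))  # True
```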
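A further sanity check one could add: with `ngroups == nheads`, GQA degenerates to ordinary multi-head attention, and since the test fills every parameter with the same constant, `MyGQA` should then agree with `nn.MultiheadAttention` up to floating-point error (matching arbitrary weights would require copying the `in_proj` slices, which the constant fill sidesteps). A hedged sketch, again assuming `MyGQA` from above:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 256, 8
bs, q_len = 3, 2
query = torch.randn(bs, q_len, embed_dim)
key = torch.randn(bs, q_len, embed_dim)
value = torch.randn(bs, q_len, embed_dim)

# ngroups == nheads: one K/V head per query head, i.e. plain MHA.
gqa = MyGQA(num_heads, embed_dim, num_heads)
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
for param in gqa.parameters():
    param.data.fill_(0.1)
for param in mha.parameters():
    param.data.fill_(0.1)

out_gqa, _ = gqa(query, key, value)
out_mha, _ = mha(query, key, value)
# Expected to be tiny (pure float rounding error).
print(torch.max(torch.abs(out_gqa - out_mha)).item())
```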