F.Linear() 和 nn.Linear() 的区别

时间:2024-01-29 17:13:56

F.Linear()

源码

可以看到nn.Linear内部调用了F.Linear,相当于是将其封装了,并自动地对参数进行了初始化。如果我们想自己初始化参数,那么可以不用nn.Linear。

为了灵活地对参数按照自己的方式进行初始化,可以借鉴fairseq的初始化做法

	def reset_parameters(self):
               nn.init.xavier_uniform_(self.in_proj_weight)
               nn.init.xavier_uniform_(self.out_proj.weight)
               if self.in_proj_bias is not None:
                   nn.init.constant_(self.in_proj_bias, 0.)
                   nn.init.constant_(self.out_proj.bias, 0.)
               if self.bias_k is not None:
                   nn.init.xavier_normal_(self.bias_k)
               if self.bias_v is not None:	
                   nn.init.xavier_normal_(self.bias_v)
 self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
 if bias:
      self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
 else:
      self.register_parameter(\'in_proj_bias\', None)
 self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
if add_bias_kv:
     self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))        
     self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
 else:
     self.bias_k = self.bias_v = None