MLP分类效果一般好于线性分类器,即将特征输入MLP中再经过softmax来进行分类。
具体实现为将原先线性分类模块:
1
|
self .classifier = nn.Linear(config.hidden_size, num_labels)
|
替换为:
1
|
self .classifier = MLP(config.hidden_size, num_labels)
|
并且添加MLP模块:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
class MLP(nn.Module):
def __init__( self , input_size, common_size):
super (MLP, self ).__init__()
self .linear = nn.Sequential(
nn.Linear(input_size, input_size / / 2 ),
nn.ReLU(inplace = True ),
nn.Linear(input_size / / 2 , input_size / / 4 ),
nn.ReLU(inplace = True ),
nn.Linear(input_size / / 4 , common_size)
)
def forward( self , x):
out = self .linear(x)
return out
|
看一下模块结构:
1
2
|
mlp = MLP( 1000 , 3 )
print (mlp)
|
以上这篇关于Pytorch的MLP模块实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_33373858/article/details/88108153