我就废话不多说了,直接上代码吧!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
class Net(nn.Module):
def __init__( self , model):
super (Net, self ).__init__()
#取掉model的后两层
self .resnet_layer = nn.Sequential( * list (model.children())[: - 2 ])
self .transion_layer = nn.ConvTranspose2d( 2048 , 2048 , kernel_size = 14 , stride = 3 )
self .pool_layer = nn.MaxPool2d( 32 )
self .Linear_layer = nn.Linear( 2048 , 8 )
def forward( self , x):
x = self .resnet_layer(x)
x = self .transion_layer(x)
x = self .pool_layer(x)
x = x.view(x.size( 0 ), - 1 )
x = self .Linear_layer(x)
return x
|
1
2
|
resnet = models.resnet50(pretrained = True )
model = Net(resnet)
|
以上这篇pytorch 修改预训练model实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/whut_ldz/article/details/78874977