pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。
两个有序字典找不同
模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。
1
2
3
4
5
6
7
8
9
10
11
|
model = ResNet18( 1 )
model_dict1 = torch.load( 'resnet18.pth' )
model_dict2 = model.state_dict()
model_list1 = list (model_dict1.keys())
model_list2 = list (model_dict2.keys())
len1 = len (model_list1)
len2 = len (model_list2)
minlen = min (len1, len2)
for n in range (minlen):
if model_dict1[model_list1[n]].shape ! = model_dict2[model_list2[n]].shape:
err = 1
|
自己搭建模型的注意事项
搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。
1
2
3
4
5
6
7
8
9
10
11
12
13
|
model = ResNet18( 1 )
model_dict1 = torch.load( 'resnet18.pth' )
model_dict2 = model.state_dict()
model_list1 = list (model_dict1.keys())
model_list2 = list (model_dict2.keys())
len1 = len (model_list1)
len2 = len (model_list2)
minlen = min (len1, len2)
for n in range (minlen):
if model_dict1[model_list1[n]].shape ! = model_dict2[model_list2[n]].shape:
continue
model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)
|
完整的代码见自己搭建resnet18网络并加载torchvision自带权重
新增的改进代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
model_dict1 = torch.load( 'yolov5.pth' )
model_dict2 = model.state_dict()
model_list1 = list (model_dict1.keys())
model_list2 = list (model_dict2.keys())
len1 = len (model_list1)
len2 = len (model_list2)
m, n = 0 , 0
while True :
if m > = len1 or n > = len2:
break
layername1, layername2 = model_list1[m], model_list2[n]
w1, w2 = model_dict1[layername1], model_dict2[layername2]
if w1.shape ! = w2.shape:
continue
model_dict2[layername2] = model_dict1[layername1]
m + = 1
n + = 1
model.load_state_dict(model_dict2)
|
如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。
补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配
看代码吧~
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
def __init__( self , rgb_range):
super (Classification_att, self ).__init__()
self .vgg19 = models.vgg19(pretrained = True )
vgg = models.vgg19(pretrained = True ).features
conv_modules = [m for m in vgg]
self .vgg_conv = nn.Sequential( * conv_modules[: 37 ])
classfi = models.vgg19(pretrained = True ).classifier
classif_modules = [n for n in classfi]
self .vgg_class = nn.Sequential( * classif_modules[: 4 ])
vgg_mean = ( 0.485 , 0.456 , 0.406 )
vgg_std = ( 0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
self .sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
for p in self .vgg_conv.parameters():
p.requires_grad = False
for p in self .vgg_class.parameters():
p.requires_grad = False
self .classifi = nn.Sequential(
nn.Linear( 4096 , 1024 ),
nn.ReLU( True ),
nn.Linear( 1024 , 256 ),
nn.ReLU( True ),
nn.Linear( 256 , 64 ),
)
def forward( self , x):
x = F.interpolate(x, size = [ 224 , 224 ], scale_factor = None , mode = 'bilinear' ,
align_corners = False )
x = self .sub_mean(x)
x = self .vgg_conv(x)
x = self .vgg_class(x) #执行这部报错,说张量不匹配
|
原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的
查看vgg的pytorch源码发现是
1
2
3
4
5
|
x = self .features(x)
x = self .avgpool(x)
x = torch.flatten(x, 1 )
x = self .classifier(x)
#自己的代码没有torch.flatten(x, 1)这步
|
所以自己的少了一步
1
|
x = torch.flatten(x, 1 )
|
补上就好了!
以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_34288751/article/details/114160725