先说结论
model.state_dict()
是浅拷贝,返回的参数仍然会随着网络的训练而变化。
应该使用deepcopy(model.state_dict())
,或将参数及时序列化到硬盘。
再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。
补充:pytorch中state_dict的理解
在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。
其实看了如下代码的输出应该就懂了
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
def __init__( self ):
super (TheModelClass, self ).__init__()
self .conv1 = nn.Conv2d( 3 , 6 , 5 )
self .pool = nn.MaxPool2d( 2 , 2 )
self .conv2 = nn.Conv2d( 6 , 16 , 5 )
self .fc1 = nn.Linear( 16 * 5 * 5 , 120 )
self .fc2 = nn.Linear( 120 , 84 )
self .fc3 = nn.Linear( 84 , 10 )
def forward( self , x):
x = self .pool(F.relu( self .conv1(x)))
x = self .pool(F.relu( self .conv2(x)))
x = x.view( - 1 , 16 * 5 * 5 )
x = F.relu( self .fc1(x))
x = F.relu( self .fc2(x))
x = self .fc3(x)
return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr = 0.001 , momentum = 0.9 )
# Print model's state_dict
print ( "Model's state_dict:" )
for param_tensor in model.state_dict():
print (param_tensor, "\t" , model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print ( "Optimizer's state_dict:" )
for var_name in optimizer.state_dict():
print (var_name, "\t" , optimizer.state_dict()[var_name])
|
输出如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
Model's state_dict:
conv1.weight torch.Size([ 6 , 3 , 5 , 5 ])
conv1.bias torch.Size([ 6 ])
conv2.weight torch.Size([ 16 , 6 , 5 , 5 ])
conv2.bias torch.Size([ 16 ])
fc1.weight torch.Size([ 120 , 400 ])
fc1.bias torch.Size([ 120 ])
fc2.weight torch.Size([ 84 , 120 ])
fc2.bias torch.Size([ 84 ])
fc3.weight torch.Size([ 10 , 84 ])
fc3.bias torch.Size([ 10 ])
Optimizer's state_dict:
state {}
param_groups [{ 'lr' : 0.001 , 'momentum' : 0.9 , 'dampening' : 0 , 'weight_decay' : 0 , 'nesterov' : False , 'params' : [ 2238501264336 , 2238501329800 , 2238501330016 , 2238501327136 , 2238501328576 , 2238501329728 , 2238501327928 , 2238501327064 , 2238501330808 , 2238501328288 ]}]
|
我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!
补充:pytorch保存模型时报错***object has no attribute 'state_dict'
定义了一个类BaseNet并实例化该类:
1
|
net = BaseNet()
|
保存net时报错 object has no attribute 'state_dict'
1
|
torch.save(net.state_dict(), models_dir)
|
原因是定义类的时候不是继承nn.Module类,比如:
1
2
|
class BaseNet( object ):
def __init__( self ):
|
把类定义改为
1
2
3
|
class BaseNet(nn.Module):
def __init__( self ):
super (BaseNet, self ).__init__()
|
以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。如有错误或未考虑完全的地方,望不吝赐教。
原文链接:https://www.cnblogs.com/LukeStepByStep/p/11248361.html