之前我们已经学习了如何创建神经网络,如何加载数据集,如何训练模型。
我们要知道,训练一个模型是要消耗很多算力资源的,模型越大消耗的人力物力财力越大,所以我们要避免重复造*,模型训练好了以后我们要学会将其保存下来。这一节我们就讲一下如何保存模型,以及在之后使用的时候如何加载我们保存的模型。
import torch
import torchvision.models as models
简单导包……
保存模型的参数
我们训练一个模型,让模型学习,其实就是让模型学习并更新其参数,所以保存模型其实就是将模型的参数保存下来。
PyTorch保存模型就是将学习好的参数使用torch.save
保存到内部状态字典state_dict
中。
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
上面代码就是保存vgg16这个网络的参数,之后pytorch就会自动为你保存模型参数了。显示100%之后就可以看一下你目录下边有无model_weights.pth文件了。
注意,model = models.vgg16(pretrained=True)
中有个pretrained=True
,我们这里用的是人家训练好的vgg16模型,就是我们保存的模型就是人家已经训练的vgg16的参数。
保存模型会显示如下:
加载模型参数
要加载模型参数之前,你首先需要实例化一个同样的网络结构,然后使用load_state_dict()
方法加载模型参数即可。
model = models.vgg16()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
可以看到这里第一句是model = models.vgg16()
,括号里是没有pretrained=True
的。因为这里我们不需要人家的网络了,我们是要把我们之前保存的参数加载进来,所以我们只需要声明我们是什么网络结构就可以了,之后就是使用model.load_state_dict(torch.load('model_weights.pth文件中加载进来。'))
将我们自己的参数从model_weights.pth文件中加载进来。
注意,在使用模型做预测之前一定要调用model.eval()
!!! 这样将模型中的dropout和batch normalization层设置为评估模式,否则在预测阶段会出现不一样的结果。
如果使用jupyter notebook,它会给你显示模型结构,使用其他IDE可能需要手动打印一下。
保存并加载模型结构
当我们加载模型参数的时候,我们首先要实例化模型这个类,因为这个类定义了模型网络结构。
你可能想说,在这里是直接用的人家vgg16,那如果像之前那样我们的自定义网络怎么办?不就没办法直接用model = models.vgg16()
加载模型结构了?所以我们还需要保存模型结构。保存好模型之后,我们才可以像上一节一样,直接将参数传给模型。
# 保存模型结构
torch.save(model, 'model.pth')
# 加载模型结构
model = torch.load('model.pth')