调教好的模型不要扔,留着以后直接用!

时间:2022-12-21 14:14:03

之前我们已经学习了如何创建神经网络如何加载数据集如何训练模型

我们要知道,训练一个模型是要消耗很多算力资源的,模型越大消耗的人力物力财力越大,所以我们要避免重复造*,模型训练好了以后我们要学会将其保存下来。这一节我们就讲一下如何保存模型,以及在之后使用的时候如何加载我们保存的模型。

import torch
import torchvision.models as models

简单导包……

保存模型的参数

我们训练一个模型,让模型学习,其实就是让模型学习并更新其参数,所以保存模型其实就是将模型的参数保存下来。

PyTorch保存模型就是将学习好的参数使用torch.save保存到内部状态字典state_dict中。

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

上面代码就是保存vgg16这个网络的参数,之后pytorch就会自动为你保存模型参数了。显示100%之后就可以看一下你目录下边有无model_weights.pth文件了。

调教好的模型不要扔,留着以后直接用!

注意,model = models.vgg16(pretrained=True)中有个pretrained=True,我们这里用的是人家训练好的vgg16模型,就是我们保存的模型就是人家已经训练的vgg16的参数。

保存模型会显示如下:

调教好的模型不要扔,留着以后直接用!

加载模型参数

要加载模型参数之前,你首先需要实例化一个同样的网络结构,然后使用load_state_dict()方法加载模型参数即可。

model = models.vgg16() 
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

可以看到这里第一句是model = models.vgg16(),括号里是没有pretrained=True的。因为这里我们不需要人家的网络了,我们是要把我们之前保存的参数加载进来,所以我们只需要声明我们是什么网络结构就可以了,之后就是使用model.load_state_dict(torch.load('model_weights.pth文件中加载进来。'))将我们自己的参数从model_weights.pth文件中加载进来。

注意,在使用模型做预测之前一定要调用model.eval()!!! 这样将模型中的dropout和batch normalization层设置为评估模式,否则在预测阶段会出现不一样的结果。

如果使用jupyter notebook,它会给你显示模型结构,使用其他IDE可能需要手动打印一下。

调教好的模型不要扔,留着以后直接用!

保存并加载模型结构

当我们加载模型参数的时候,我们首先要实例化模型这个类,因为这个类定义了模型网络结构。

你可能想说,在这里是直接用的人家vgg16,那如果像之前那样我们的自定义网络怎么办?不就没办法直接用model = models.vgg16()加载模型结构了?所以我们还需要保存模型结构。保存好模型之后,我们才可以像上一节一样,直接将参数传给模型。

# 保存模型结构
torch.save(model, 'model.pth')

# 加载模型结构
model = torch.load('model.pth')