Pytorch中张量的保存与加载
保存张量
在Pytorch中,一个约定俗成的方法是使用.pt扩展的文件格式来保存张量,使用的方法为()。
函数原型与参数说明
import torch
def save(obj, f: Union[str, os.PathLike, BinaryIO],
pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
"""
pytorch框架的原型代码
"""
pass
# 参数说明
# obj:要保存的对象,类型为tensor
# f:保存的文件名,可以是文件路径(包含文件名的字符串)、可以是字符流、也可以是文件对象
# pickle:Python中的一个模块,实现了用于序列化和反序列化Python对象结构的二进制协议
# pickle_module:用来协议化元数据和对象的协议
# pickle_protocol:可以指定来默认覆盖的协议
# 使用save方法
def save_tensor():
# 直接保存为一个张量
x = torch.Tensor([1, 2, 3])
torch.save(x, 'save_tensor.pt')
# 保存为字符流的格式
buffer = io.BytesIO()
torch.save(x, buffer)
加载张量
在Pytorch中,使用()方法加载()方法保存的文件。
函数原型与参数说明
import torch
def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
"""
Pytorch框架的原型代码
"""
pass
# 参数说明
# f:保存的文件名
# map_location:加载位置,即将这个张量加载到哪,可选的内容包括:函数、、字符串以及指定如何重新映射存储的字典
# pickle_module:用来协议化元数据和对象的协议
# pickle_load_args:需要加载的pickle模块的参数设置。这个包含的内容相当丰富,感兴趣的可以去阅读Pytorch的官方手册
# 使用load方法
def tensor_load():
# 小白式加载(最常用)
torch.load('save_tensor.pt')
# 加载到CPU中
torch.load('save_tensor.pt', map_location=torch.device('cpu'))
# 使用函数加载到CPU中
torch.load('save_tensor.pt', map_location=lambda storage, loc: storage)
# 加载到GPU1中
torch.load('save_tensor.pt', map_location=lambda storage, loc: storage.cuda(1))
# 从GPU0加载到GPU1中
torch.load('save_tensor.pt', map_location={'cuda: 1': 'cuda: 0'})
# 指定加载的编码方式
torch.load('save_tensor.pt', encoding='ascii')
# 加载字符流格式的张量
with open('save_tensor.pt', 'rb') as f:
buffer = io.BytesIO(f.read())
torch.load(buffer)