PyTorch 1.0 中文文档:Torch 脚本

时间:2021-10-07 06:52:22

译者:keyianpai

Torch脚本是一种从PyTorch代码创建可序列化和可优化模型的方法。用Torch脚本编写的代码可以从Python进程中保存,并在没有Python依赖的进程中加载。

我们提供了一些工具帮助我们将模型从纯Python程序逐步转换为可以独立于Python运行的Torch脚本程序。Torch脚本程序可以在其他语言的程序中运行(例如,在独立的C ++程序中)。这使得我们可以使用熟悉的工具在PyTorch中训练模型,而将模型导出到出于性能和多线程原因不能将模型作为Python程序运行的生产环境中去。

class torch.jit.ScriptModule(optimize=True)

ScriptModule与其内部的Torch脚本函数可以通过两种方式创建:

追踪:

使用torch.jit.trace。torch.jit.trace以现有模块或python函数和样例输入作为参数,它会运行该python函数,记录函数在所有张量上执行的操作,并将记录转换为Torch脚本方法以作为ScriptModule的forward方法。创建的模块包含原始模块的所有参数。

例:

import torch
def foo(x, y):
    return 2*x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

阅读全文/改进本文