如何看待PyTorch 2.0?

时间:2023-01-04 12:11:08

如何看待PyTorch 2.0?
 

作者|吴育昕
 

1

为什么是TorchDynamo
 

Graph capture 把用户 Python 写的模型代码变成 graph,是一切编译的根基。而 PyTorch 在试了这么多方案之后似乎已经锁定 TorchDynamo 作为 graph capture 的未来方向了,所以写一点关于 TorchDynamo 的内容,主要是解释到底为什么要做这个东西(离开FB一年了,内容主要凭自己的猜测和理解)。

 

一句话尽量解释 TorchDynamo 干了什么:利用 PEP523(https://peps.python.org/pep-0523/) 的 API 在用户执行每个 python frame 前, 拿到这个 frame 的 bytecode,把其中认识的部分用 tracing 的方式提取出 graph (并送给后端编译)不认识的部分维持原样。把修改后的 bytecode还给 CPython 跑。

 

由于 LazyTensor 和 TorchDynamo 都做 tracing,并且都是 best-effort graph capture,即只编译自己能 capture 的部分,capture 不到的用 Python 跑 (aka Python fallback),所以观感上两者可能会差不多。

 

然而,这两个方案的差别正是 TorchDynamo 关键的地方:

 

LazyTensor 是个纯靠 tracing 的方案,不可避免的问题是「只能看见 trace 到的部分,只有 trace 一下才知道哪里不能 trace」。而每次执行模型的时候,不能 trace 的部分可能不太一样。为了保证正确性,LazyTensor 就不得不每次执行都要重新 trace。举个极端的例子,模型里写了一个torch.add(tensor, random.random()) ,其中 random 是个 LazyTensor 看不见摸不着的 Python 函数,那只有重新 trace 才能保证正确性。

 

而当 TorchDynamo 修改 bytecode 的时候,事情就不太一样了:

 

  1. 在 bytecode 里能够看得见所有需要的信息,所以能够证明「这段模型代码没有用到奇怪的东西所以不需要重新 trace」。

     

  2. 光证明了「不需要 trace」不代表可以真的不 trace因为用户的代码还是一行行给 Python 来跑的。但是 TorchDynamo 又来了:CPython 到底跑什么 bytecode 是可以被它换掉的!

 

因此它可以做到这么一件事:当用户 call 一个被 capture 过的模型时模型里大部分 Python 代码都相当于不存在了,连 symbolic execution 的 overhead 都没有而被换成了编译后的 native code。这一点在以前所有的 partial graph capture 的方案里是做不到的:
 

  • LazyTensor 即使编译过的 graph 也要每次重新在 Python 里 trace 一遍,才能发现「哦,这个 graph 我曾见过的」。

  • @torch.jit.script 、@tf.function、 @jax.jit 可以把装饰的 python code 换成编译后的,但是这都依赖用户把这个 subgraph refactor 出来放到一个单独的函数里。而 TorchDynamo 是全自动不需要用户改代码的。

 

  • 这种 refactor 除了增加额外的工作量之外还可能与用户的代码结构冲突,因为 「用来编译的graph的边界」与「用户代码需要的抽象边界」很可能不 match:例如用户本来希望写三个函数但是最佳的优化是把其中两个半函数变成一个 graph这会让用户很尴尬。

这只是一个最直接的例子。由于能够读写 bytecode,理论上 TorchDynamo 能 access 更多 LazyTensor 根本没有的信息做更多事情(后面会提到)。而读写 bytecode 的难度比 source code要低不少所以成为了一个可行的方案。

 

2
whole-graph capture用处不大?

 

有的人可能会说上面提到的东西对 whole-graph capture 没太大用啊。 

我觉得确实是这样:TorchDynamo 是一个对 partial-graph capture 追求极致的方案能够对几乎所有的 Python 实现的模型开箱即用有加速不用改代码——前提是还要跑 Python 作为 fallback。但是部署一般需要的是 whole-graph capture 整个模型在一个 graph 里不能用 Python。

 

用 tracing 做 whole-graph capture 的前提是用户要在 Python 代码里避免所有不能被 trace 的东西最常见的用户要做的三件事是:使用 symbolic shape使用 symbolic control flow,禁用除了当前 tensor library之外的所有其它 library。如果用户做到了这些那只要一个普通的 symbolic tracing 就能 capture 到完整的 graph 了不需要 TorchDynamo 这么复杂的机制。TorchDynamo 可能可以略微简化用户做这些的工作量但我感觉不会有本质不同。

 

我个人的观点是从实用角度出发要求用户做上面几件事不算是太复杂的要求:禁用其他 library 理所应当就不说了;即使今天 PyTorch 还没有很好的 symbolic {shape, control flow}但是只要用 @torch.jit.script_if_tracing 来处理少量的 symbolic shape 和 symbolic control flow大多数模型都是可以正确的被 torch.jit.tracecapture 的。Meta 应该有几十上百个 vision 模型实现在 detectron2/d2go 里, 目前基本都是走这条路部署的(我另有篇文章https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/介绍这里面的细节)。

 

TensorFlow 的 whole-graph capture 就简单了:TF 从第一天就有很好的 symbolic shape 和 symbolic control flow,用就完了。tf.autograph 甚至还自动化了一部分 control flow 的改写工作。

 

所以用户少量改代码仍然是必须的。当然,TorchDynamo 毕竟有着"改变用户要跑的 bytecode" 的超能力。所以如果愿意的话,理论上可以让用户的 whole-graph capture 工作变得更简单。例如:
 

  • 模型中间的一些像 if x.shape[0] > 100 的分支有的可以通过 shape inference 等价转移到模型开头的。这样的话就可以 capture 到更大的没有分支的 subgraph。 这件事在 TorchDynamo 里现在叫做 "guard"。 

     

  • 理论上可以把 python control flow 自动替换成 symbolic 的,类似tf.autograph 做的事情只不过输入是 bytecode 而不是 source code。 
     

目前 TorchDynamo 的 "nopython" 模式就是 whole-graph capture 了。不过似乎还不是工作重心 (以下引用自https://docs.google.com/document/d/1tlgPcR2YmC3PcQuYDPUORFmEaBPQEmo8dsh4eUjnlyI/edit#heading=h.rmxeybu31e0):

 

PT2 will provide infrastructure for a no python export mode for edge and performance sensitive serving cases. The PT2 team won’t drive this end to end stack, but we will keep a feedback loop with the teams in charge of this and ensure the components we build are reusable in these situations.

 

不过与此同时PyTorch 2.0 最近在完善 symbolic shape 的支持;functorch 里也加入了少量 control flow operator。这算是利好 whole-graph capture 的消息。

 

3
总结

 

总的来说由于 TorchDynamo 在 bytecode 层面做文章能做到一些其他方案做不到的事情。它的优点主要为 partial graph capture 服务: 让用户的 Python 模型代码在 0 修改的情况下就能 capture 并获得加速。这体现了 PyTorch 对于 "Python first" 哲学的执念。这种执着是否有必要见仁见智。

 

TorchDynamo 的主要优势来自对 bytecode 的读写。JIT scripting compiler 的失败表明在 source code level 做不了太多事TorchDynamo 能在 bytecode level 做事情确实很巧妙。 不过 要完整的复刻 CPython bytecode interpreter 它的工作量、 维护难度(以及出 bug 的概率)都是不小的。

 

另外TorchDynamo 对 whole-graph capture 没有很大的帮助。 对于复杂的模型用户该做的改写还是得做。不过我估计 2.0 至少能对「用户该做什么」有个清晰的说法。

 

当然最后 PT2 到底能不能把 compiler 做好还有很多其他因素:IR 怎么设计何时specialize/recompile,各种 backend 不同的特性等等。比如 TorchDynamo 和 LazyTensor 使用的 IR 其实也不一样。但是本文只讨论 graph capture,其他问题就不提了。

(本文经授权后发布。原文:https://www.zhihu.com/question/570220953/answer/2798657470)

 

其他人都在看

欢迎Star、试用OneFlow最新版本:https://github.com/Oneflow-Inc/oneflow/


如何看待PyTorch 2.0?