作者|吴育昕
1
为什么是TorchDynamo
Graph capture 把用户 Python 写的模型代码变成 graph,是一切编译的根基。而 PyTorch 在试了这么多方案之后似乎已经锁定 TorchDynamo 作为 graph capture 的未来方向了,所以写一点关于 TorchDynamo 的内容,主要是解释到底为什么要做这个东西(离开FB一年了,内容主要凭自己的猜测和理解)。
一句话尽量解释 TorchDynamo 干了什么:利用 PEP523(https://peps.python.org/pep-0523/) 的 API 在用户执行每个 python frame 前, 拿到这个 frame 的 bytecode,把其中认识的部分用 tracing 的方式提取出 graph (并送给后端编译),不认识的部分维持原样。把修改后的 bytecode还给 CPython 跑。
由于 LazyTensor 和 TorchDynamo 都做 tracing,并且都是 best-effort graph capture,即只编译自己能 capture 的部分,capture 不到的用 Python 跑 (aka Python fallback),所以观感上两者可能会差不多。
然而,这两个方案的差别正是 TorchDynamo 关键的地方:
LazyTensor 是个纯靠 tracing 的方案,不可避免的问题是「只能看见 trace 到的部分,只有 trace 一下才知道哪里不能 trace」。而每次执行模型的时候,不能 trace 的部分可能不太一样。为了保证正确性,LazyTensor 就不得不每次执行都要重新 trace。举个极端的例子,模型里写了一个torch.add(tensor, random.random()) ,其中 random 是个 LazyTensor 看不见摸不着的 Python 函数,那只有重新 trace 才能保证正确性。
而当 TorchDynamo 修改 bytecode 的时候,事情就不太一样了:
-
在 bytecode 里能够看得见所有需要的信息,所以能够证明「这段模型代码没有用到奇怪的东西所以不需要重新 trace」。
-
光证明了「不需要 trace」不代表可以真的不 trace,因为用户的代码还是一行行给 Python 来跑的。但是 TorchDynamo 又来了:CPython 到底跑什么 bytecode 是可以被它换掉的!
因此它可以做到这么一件事:当用户 call 一个被 capture 过的模型时,模型里大部分 Python 代码都相当于不存在了,连 symbolic execution 的 overhead 都没有,而被换成了编译后的 native code。这一点在以前所有的 partial graph capture 的方案里是做不到的:
-
LazyTensor 即使编译过的 graph 也要每次重新在 Python 里 trace 一遍,才能发现「哦,这个 graph 我曾见过的」。
-
@torch.jit.script 、@tf.function、 @jax.jit 可以把装饰的 python code 换成编译后的,但是这都依赖用户把这个 subgraph refactor 出来放到一个单独的函数里。而 TorchDynamo 是全自动不需要用户改代码的。
-
这种 refactor 除了增加额外的工作量之外,还可能与用户的代码结构冲突,因为 「用来编译的graph的边界」与「用户代码需要的抽象边界」很可能不 match:例如用户本来希望写三个函数,但是最佳的优化是把其中两个半函数变成一个 graph,这会让用户很尴尬。
这只是一个最直接的例子。由于能够读写 bytecode,理论上 TorchDynamo 能 access 更多 LazyTensor 根本没有的信息,做更多事情(后面会提到)。而读写 bytecode 的难度比 source code要低不少,所以成为了一个可行的方案。
2
whole-graph capture用处不大?
有的人可能会说,上面提到的东西对 whole-graph capture 没太大用啊。
我觉得确实是这样:TorchDynamo 是一个对 partial-graph capture 追求极致的方案,能够对几乎所有的 Python 实现的模型开箱即用有加速,不用改代码——前提是还要跑 Python 作为 fallback。但是部署一般需要的是 whole-graph capture 整个模型在一个 graph 里不能用 Python。
用 tracing 做 whole-graph capture 的前提是用户要在 Python 代码里避免所有不能被 trace 的东西,最常见的用户要做的三件事是:使用 symbolic shape,使用 symbolic control flow,禁用除了当前 tensor library之外的所有其它 library。如果用户做到了这些,那只要一个普通的 symbolic tracing 就能 capture 到完整的 graph 了,不需要 TorchDynamo 这么复杂的机制。TorchDynamo 可能可以略微简化用户做这些的工作量,但我感觉不会有本质不同。
我个人的观点是,从实用角度出发,要求用户做上面几件事不算是太复杂的要求:禁用其他 library 理所应当就不说了;即使今天 PyTorch 还没有很好的 symbolic {shape, control flow},但是只要用 @torch.jit.script_if_tracing 来处理少量的 symbolic shape 和 symbolic control flow,大多数模型都是可以正确的被 torch.jit.tracecapture 的。Meta 应该有几十上百个 vision 模型实现在 detectron2/d2go 里, 目前基本都是走这条路部署的(我另有篇文章https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/介绍这里面的细节)。
TensorFlow 的 whole-graph capture 就简单了:TF 从第一天就有很好的 symbolic shape 和 symbolic control flow,用就完了。tf.autograph 甚至还自动化了一部分 control flow 的改写工作。
所以,用户少量改代码仍然是必须的。当然,TorchDynamo 毕竟有着"改变用户要跑的 bytecode" 的超能力。所以如果愿意的话,理论上可以让用户的 whole-graph capture 工作变得更简单。例如:
-
模型中间的一些像 if x.shape[0] > 100 的分支,有的可以通过 shape inference 等价转移到模型开头的。这样的话就可以 capture 到更大的没有分支的 subgraph。 这件事在 TorchDynamo 里现在叫做 "guard"。
-
理论上可以把 python control flow 自动替换成 symbolic 的,类似tf.autograph 做的事情,只不过输入是 bytecode 而不是 source code。
目前 TorchDynamo 的 "nopython" 模式就是 whole-graph capture 了。不过似乎还不是工作重心 (以下引用自https://docs.google.com/document/d/1tlgPcR2YmC3PcQuYDPUORFmEaBPQEmo8dsh4eUjnlyI/edit#heading=h.rmxeybu31e0):
PT2 will provide infrastructure for a no python export mode for edge and performance sensitive serving cases. The PT2 team won’t drive this end to end stack, but we will keep a feedback loop with the teams in charge of this and ensure the components we build are reusable in these situations.
不过与此同时,PyTorch 2.0 最近在完善 symbolic shape 的支持;functorch 里也加入了少量 control flow operator。这算是利好 whole-graph capture 的消息。
3
总结
总的来说,由于 TorchDynamo 在 bytecode 层面做文章,能做到一些其他方案做不到的事情。它的优点主要为 partial graph capture 服务: 让用户的 Python 模型代码在 0 修改的情况下就能 capture 并获得加速。这体现了 PyTorch 对于 "Python first" 哲学的执念。这种执着是否有必要,见仁见智。
TorchDynamo 的主要优势来自对 bytecode 的读写。JIT scripting compiler 的失败表明在 source code level 做不了太多事,TorchDynamo 能在 bytecode level 做事情确实很巧妙。 不过 , 要完整的复刻 CPython bytecode interpreter , 它的工作量、 维护难度(以及出 bug 的概率)都是不小的。
另外,TorchDynamo 对 whole-graph capture 没有很大的帮助。 对于复杂的模型,用户该做的改写还是得做。不过我估计 2.0 至少能对「用户该做什么」有个清晰的说法。
当然,最后 PT2 到底能不能把 compiler 做好,还有很多其他因素:IR 怎么设计,何时specialize/recompile,各种 backend 不同的特性等等。比如 TorchDynamo 和 LazyTensor 使用的 IR 其实也不一样。但是本文只讨论 graph capture,其他问题就不提了。
(本文经授权后发布。原文:https://www.zhihu.com/question/570220953/answer/2798657470)
其他人都在看
欢迎Star、试用OneFlow最新版本:https://github.com/Oneflow-Inc/oneflow/