Jax计算框架的JIT编译的static特性

时间:2024-01-23 16:05:09

官方:
https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables




jax的单步操作是具有编译和缓存特性的,但是单步操作之间是需要切换为python操作的,因而会影响代码运行效率,为了提高运行效率jax可以通过jit函数对多个jax的单步操作合并成一个编译和缓存操作。这些单步操作被合并到单个函数中后被jit函数包装后不需要在中间步骤切换为Python代码,因此运行效率更高。


代码:

from jax import jit
import jax.numpy as jnp

import numpy as np

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

运行结果:

Jax计算框架的JIT编译的static特性_python

再次调用:

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)

运行结果:

Jax计算框架的JIT编译的static特性_缓存_02


可以看到,第一次运行时jax框架把jit函数内的操作进行了编译,然后再一次调用f函数时则不执行python代码,而是直接执行编译后的jax的后端代码。


但是要注意的是,jit编译的代码需要是静态static的,也就是jit编译的函数,其输入的参数必须是static的,即shape和type是静态的,这里的静态是指输入的参数在jit的编译的函数内其参数变量的shape和type是不能改变的(由于jax的array变量其内部值也是不能更改的,因此这点是基本的保证),并且依赖输入参数变量的其他变量的shape和type也是静态的,即只能依赖于输入变量的shape和type而不能随意变动。

换句话说,在jax的jit编译的函数内参数的shape和type如果变化也只能是依赖输入参数的type和shape的,而且与输入参数参数相关的代码结构只能是固定的或者说是依赖于输入参数的type和shape的,而像if判断这样需要依赖数值大小的语句是不能够依赖于输入参数的value的,否则在进行jit编译时是会报错的。

给出Demo:

Jax计算框架的JIT编译的static特性_缓存_03

这个例子中,jit包裹的函数内的判断语句依赖于输入参数的具体值的,因此无法进行jit编译,为此,我们可以把jit内的if判断语句依赖的数值设置为静态static的,以此进行jit编译。

修改后代码:

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  print("x:", x)
  print("neg:", neg)
  return -x if neg else x

f(111, False)


给出第一次运行和第二次运行结果:

Jax计算框架的JIT编译的static特性_缓存_04


可以看到,在第一次运算时进行了jit编译,因此会执行jit内部的python操作,打印出x和neg的数值,第二次运行时是直接调用已经编译好的jax后端代码。

之所以这里没有报错,是因为已经在编译时把neg变量设置为非jax的Traced变量,而是python的False变量,也就是说这时候编译的jax代码中neg变量是直接使用False值的,而不是已变量形式存在的。

当然由于该种方法时把neg变量变为固定值后进行编译的,因此如果neg的值不为False或者输入的x变量的shape或type变化,也会因为没有缓存对应的代码而重新执行编译,例如:

改变x的shape,引起jit的重新编译:

Jax计算框架的JIT编译的static特性_python_05

改变neg的值为True,引起jit的重新编译:

Jax计算框架的JIT编译的static特性_缓存_06