Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么

时间:2024-01-23 18:34:52





给出一个jax的jit的循环结构代码:

from jax import jit, random
import jax.numpy as jnp

from functools import partial


@partial(jit, static_argnums=(2,))
def f(x, y, z):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  print(f"  z = {z}")
  for _ in range(z):
    y = jnp.dot(x + 0.0001, y + 0.0001)
  print(f"  result = {y}")
  return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 10000


运行时间:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_python


这里需要注意,上面的运算时间并没有包括jit的编译时间,只是编译后的jax的后端代码的运行时间。


执行编译后发现很快的有打印:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_python_02


CPU的使用率单核心满载,内存占用逐渐增加:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_python_03


Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_python_04


这里需要注意,在CPU单核心满载的同时GPU的负载为空:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_循环结构_05


一段时间后GPU才开始满载:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_循环结构_06


最后的运行结果:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_循环结构_07


对此我们给出解释:

jax的jit编译会将循环结构进行展开编译,这个过程和C++中的inline是很像的,由于这里的循环次数为10000,因此这个训练结构展开编译需要耗费掉一定的时间,而且这个编译的时间要远大于真是代码的运行时间。


为此,我们给出非notebook的环境进行测试:

运行代码:

from jax import jit, random
import jax.numpy as jnp

from functools import partial


@partial(jit, static_argnums=(2,))
def f(x, y, z):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  print(f"  z = {z}")
  for _ in range(z):
    y = jnp.dot(x + 0.0001, y + 0.0001)
  print(f"  result = {y}")
  return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 10000

f(x, y, z).block_until_ready()

运行时间:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_循环结构_08


运行代码:

from jax import jit, random
import jax.numpy as jnp

from functools import partial


@partial(jit, static_argnums=(2,))
def f(x, y, z):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  print(f"  z = {z}")
  for _ in range(z):
    y = jnp.dot(x + 0.0001, y + 0.0001)
  print(f"  result = {y}")
  return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 10000

f(x, y, z).block_until_ready()



# y = random.normal(key, (10000, ))
f(x, y, z).block_until_ready()


运行时间:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_运行时间_09


由此,我们可以估计出代码的真正运行时间为9秒左右,而jax的jit对10000次循环的编译时间为60秒左右。

我们要注意的是由于jax的jit编译的static特性,如果循环次数改变,那么会因为缓存不存在导致重新编译,而每次的编译都需要60秒,而这个编译后的函数的真正运行时间只有9秒左右。

比如下面,我们只对循环次数改为10001,则重新进入了60秒的jit编译时间:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_运行时间_10


如果我们的真实代码中这个循环次数是固定不变的,或者是循环次数较少的情况,那么在jit内部使用循环也是可以接受的,否则像本文中的这个情况,已经是编译时间远高于运行时间了,而且如果不能保证static特性而导致jit函数的多次编译,那么这种在jit内部使用循环的操作不但不能加速运算反而会大大增加运算时间。


如果我们把jit函数内的循环结构拿出来,对上面的代码进行改造,得到下面的代码:

from jax import jit, random
import jax.numpy as jnp

from functools import partial


@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  y = jnp.dot(x + 0.0001, y + 0.0001)
  print(f"  result = {y}")
  return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))

def run():
  global y
  for _ in range(10000):
    y = f(x, y)


运行时间:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_循环结构_11


为了避免编译时间对总体运行的影响,我们给出文件形式的运行:

代码:

from jax import jit, random
import jax.numpy as jnp

from functools import partial


@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  y = jnp.dot(x + 0.0001, y + 0.0001)
  print(f"  result = {y}")
  return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))

def run():
  global y
  for _ in range(10000):
    y = f(x, y)


for _ in range(100):
  run()


再次测试运行时间:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_python_12


Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么_循环结构_13


可以看到即使把编译时间考虑进去,平均后得到的运行时间依旧为9.46秒,这个运算时长和之前的把训练结构放进jax的jit函数内编译完成后的运行时间保持基本一致,也就是说不把循环控制放入jit中而是用python控制循环结构,然后jax做具体计算,最后的用时是一致的,并且还省掉了之前的大量的jit编译时间(60秒左右),可以说这种方式更加优化。

尽可能的不在jax的jit中进行循环控制是性能更优的一种保证,并且这样的设置会获得几乎一致的计算性能,并且还避免了循环结构被jit编译导致的各种问题,这个逻辑和pytorch的逻辑很像,也就是说计算框架只负责矩阵等线性代数算法的加速,而把逻辑判断和循环控制尽可能的交给python进行,这样不仅不会影响算法的整体性能还会简化算法的逻辑,使代码更稳定。