Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么(续)

时间:2024-01-23 14:56:03




从前文我们知道,jax的jit中尽可能的不要放入循环结构,因为在jit编译时会将循环结构暂开,因而会消耗掉大量的时间进行编译。

如果我们将代码中的循环次数设置为30000呢,代码如下:

from jax import jit, random
import jax.numpy as jnp

from functools import partial


@partial(jit, static_argnums=(2,))
def f(x, y, z):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  print(f"  z = {z}")
  for _ in range(z):
    y = jnp.dot(x + 0.0001, y + 0.0001)
  print(f"  result = {y}")
  return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 30000


代码的jit编译后的CPU的情况:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么(续)_运行时间

可以看到,其编译后的CPU端内存占用为120GB*0.335=40.2GB,也就是说在循环次数为30000次的情况下编译jit的内存占用就是40GB。


使用notebook运行:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么(续)_循环结构_02


可以看到这里展示的时间28.8秒依旧是编译后的运行时间,给出单文件的运行:

代码:

from jax import jit, random
import jax.numpy as jnp

from functools import partial


@partial(jit, static_argnums=(2,))
def f(x, y, z):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  print(f"  z = {z}")
  for _ in range(z):
    y = jnp.dot(x + 0.0001, y + 0.0001)
  print(f"  result = {y}")
  return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 30000

f(x, y, z).block_until_ready()


运行结果:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么(续)_python实现_03


Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么(续)_python实现_04


同样将循环结构从jit函数内提出来,使用python实现循环结构:

代码:

from jax import jit, random
import jax.numpy as jnp

from functools import partial


@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  y = jnp.dot(x + 0.0001, y + 0.0001)
  print(f"  result = {y}")
  return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))


z = 30000

for _ in range(30000):
  y = f(x, y)


运行表现:

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么(续)_运行时间_05

Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么(续)_python实现_06


可以看到,把循环结构从jit函数内提出来可以避免过长的jit编译时间,提高运算性能。