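As we saw in the previous post, Python loops should be kept out of JAX jit functions wherever possible: when the function is traced for jit compilation, the loop is unrolled, so compilation can take a very long time.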
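To make the unrolling concrete, here is a minimal sketch that traces the same loop body with two different loop counts and counts the operations JAX records. The helper names `unrolled` and `count_eqns` and the tiny 4x4 shapes are assumptions chosen for illustration; they are not part of the original experiment.

```python
import jax
import jax.numpy as jnp


def unrolled(x, y, n):
    # A plain Python loop: during tracing, every iteration is recorded separately.
    for _ in range(n):
        y = jnp.dot(x + 0.0001, y + 0.0001)
    return y


def count_eqns(n):
    # Trace the function for a given loop count and count the recorded operations.
    x = jnp.ones((4, 4))
    y = jnp.ones((4,))
    jaxpr = jax.make_jaxpr(lambda a, b: unrolled(a, b, n))(x, y)
    return len(jaxpr.jaxpr.eqns)


small = count_eqns(10)
large = count_eqns(100)
print(f"recorded operations for n=10:  {small}")
print(f"recorded operations for n=100: {large}")
```

The traced program grows with the loop count, which is why a very large count makes compilation slow and memory-hungry.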
What happens if we set the loop count to 30000? The code is as follows:
from functools import partial

from jax import jit, random
import jax.numpy as jnp

# z is marked static, so the Python loop below is unrolled at trace time.
@partial(jit, static_argnums=(2,))
def f(x, y, z):
    print("Running f():")
    print(f" x = {x}")
    print(f" y = {y}")
    print(f" z = {z}")
    for _ in range(z):
        y = jnp.dot(x + 0.0001, y + 0.0001)
    print(f" result = {y}")
    return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 30000
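Resource usage on the CPU side after the code has been jit-compiled:
As shown, CPU-side memory usage after compilation is 120 GB * 0.335 = 40.2 GB. In other words, with a loop count of 30000, jit compilation alone takes about 40 GB of memory.
Running in a notebook:
The 28.8 seconds shown here is still the run time after compilation. Next, the same code run as a standalone script:
Code: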
from functools import partial

from jax import jit, random
import jax.numpy as jnp

# z is marked static, so the Python loop below is unrolled at trace time.
@partial(jit, static_argnums=(2,))
def f(x, y, z):
    print("Running f():")
    print(f" x = {x}")
    print(f" y = {y}")
    print(f" z = {z}")
    for _ in range(z):
        y = jnp.dot(x + 0.0001, y + 0.0001)
    print(f" result = {y}")
    return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 30000

# block_until_ready() waits for the asynchronous computation to finish.
f(x, y, z).block_until_ready()
Result:
Now do the same computation with the loop moved out of the jit function and implemented as an ordinary Python loop:
Code:
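from jax import jit, random
import jax.numpy as jnp

@jit
def f(x, y):
    # These prints run only while f is being traced (once), not on every call.
    print("Running f():")
    print(f" x = {x}")
    print(f" y = {y}")
    y = jnp.dot(x + 0.0001, y + 0.0001)
    print(f" result = {y}")
    return y

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 30000

# The loop runs in Python; f is compiled once and reused on every iteration.
for _ in range(z):
    y = f(x, y)
y.block_until_ready()  # wait for the asynchronous computation to finish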
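Run-time behavior:
As shown, moving the loop out of the jit function avoids the excessively long jit compilation time and improves overall performance.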
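The comparison above can be reproduced at a much smaller scale with the sketch below. It assumes tiny, hypothetical sizes (an 8x8 matrix and 200 iterations) and illustrative function names (`loop_inside`, `step`, `loop_outside`) so that it finishes quickly. The timings it prints depend on the machine and include compilation, so treat them as indicative only.

```python
import time
from functools import partial

import jax.numpy as jnp
from jax import jit


@partial(jit, static_argnums=(2,))
def loop_inside(x, y, n):
    # The Python loop is unrolled at trace time: the compiled program holds n copies of the body.
    for _ in range(n):
        y = jnp.dot(x + 0.0001, y + 0.0001)
    return y


@jit
def step(x, y):
    # Only a single iteration is compiled.
    return jnp.dot(x + 0.0001, y + 0.0001)


def loop_outside(x, y, n):
    # The loop itself runs in Python and calls the compiled step each time.
    for _ in range(n):
        y = step(x, y)
    return y


# Small, hypothetical inputs (not the sizes used in the experiment above).
x = jnp.full((8, 8), 0.1)
y = jnp.ones((8,))
n = 200

start = time.perf_counter()
inside = loop_inside(x, y, n).block_until_ready()
t_inside = time.perf_counter() - start

start = time.perf_counter()
outside = loop_outside(x, y, n).block_until_ready()
t_outside = time.perf_counter() - start

print(f"loop inside jit:  {t_inside:.3f} s (includes compilation)")
print(f"loop outside jit: {t_outside:.3f} s (includes compilation)")
```

Both versions compute the same values. They differ only in how much work is handed to the compiler in one piece.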