给出一个jax的jit的循环结构代码:
from jax import jit, random
import jax.numpy as jnp
from functools import partial
@partial(jit, static_argnums=(2,))
def f(x, y, z):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
print(f" z = {z}")
for _ in range(z):
y = jnp.dot(x + 0.0001, y + 0.0001)
print(f" result = {y}")
return y
key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 10000
运行时间:
这里需要注意,上面的运算时间并没有包括jit的编译时间,只是编译后的jax的后端代码的运行时间。
执行编译后发现很快的有打印:
CPU的使用率单核心满载,内存占用逐渐增加:
这里需要注意,在CPU单核心满载的同时GPU的负载为空:
一段时间后GPU才开始满载:
最后的运行结果:
对此我们给出解释:
jax的jit编译会将循环结构进行展开编译,这个过程和C++中的inline是很像的,由于这里的循环次数为10000,因此这个训练结构展开编译需要耗费掉一定的时间,而且这个编译的时间要远大于真是代码的运行时间。
为此,我们给出非notebook的环境进行测试:
运行代码:
from jax import jit, random
import jax.numpy as jnp
from functools import partial
@partial(jit, static_argnums=(2,))
def f(x, y, z):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
print(f" z = {z}")
for _ in range(z):
y = jnp.dot(x + 0.0001, y + 0.0001)
print(f" result = {y}")
return y
key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 10000
f(x, y, z).block_until_ready()
运行时间:
运行代码:
from jax import jit, random
import jax.numpy as jnp
from functools import partial
@partial(jit, static_argnums=(2,))
def f(x, y, z):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
print(f" z = {z}")
for _ in range(z):
y = jnp.dot(x + 0.0001, y + 0.0001)
print(f" result = {y}")
return y
key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 10000
f(x, y, z).block_until_ready()
# y = random.normal(key, (10000, ))
f(x, y, z).block_until_ready()
运行时间:
由此,我们可以估计出代码的真正运行时间为9秒左右,而jax的jit对10000次循环的编译时间为60秒左右。
我们要注意的是由于jax的jit编译的static特性,如果循环次数改变,那么会因为缓存不存在导致重新编译,而每次的编译都需要60秒,而这个编译后的函数的真正运行时间只有9秒左右。
比如下面,我们只对循环次数改为10001,则重新进入了60秒的jit编译时间:
如果我们的真实代码中这个循环次数是固定不变的,或者是循环次数较少的情况,那么在jit内部使用循环也是可以接受的,否则像本文中的这个情况,已经是编译时间远高于运行时间了,而且如果不能保证static特性而导致jit函数的多次编译,那么这种在jit内部使用循环的操作不但不能加速运算反而会大大增加运算时间。
如果我们把jit函数内的循环结构拿出来,对上面的代码进行改造,得到下面的代码:
from jax import jit, random
import jax.numpy as jnp
from functools import partial
@jit
def f(x, y):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
y = jnp.dot(x + 0.0001, y + 0.0001)
print(f" result = {y}")
return y
key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
def run():
global y
for _ in range(10000):
y = f(x, y)
运行时间:
为了避免编译时间对总体运行的影响,我们给出文件形式的运行:
代码:
from jax import jit, random
import jax.numpy as jnp
from functools import partial
@jit
def f(x, y):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
y = jnp.dot(x + 0.0001, y + 0.0001)
print(f" result = {y}")
return y
key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
def run():
global y
for _ in range(10000):
y = f(x, y)
for _ in range(100):
run()
再次测试运行时间:
可以看到即使把编译时间考虑进去,平均后得到的运行时间依旧为9.46秒,这个运算时长和之前的把训练结构放进jax的jit函数内编译完成后的运行时间保持基本一致,也就是说不把循环控制放入jit中而是用python控制循环结构,然后jax做具体计算,最后的用时是一致的,并且还省掉了之前的大量的jit编译时间(60秒左右),可以说这种方式更加优化。
尽可能的不在jax的jit中进行循环控制是性能更优的一种保证,并且这样的设置会获得几乎一致的计算性能,并且还避免了循环结构被jit编译导致的各种问题,这个逻辑和pytorch的逻辑很像,也就是说计算框架只负责矩阵等线性代数算法的加速,而把逻辑判断和循环控制尽可能的交给python进行,这样不仅不会影响算法的整体性能还会简化算法的逻辑,使代码更稳定。