JAX是一个为高性能数值计算设计的Python库,特别是机器学习研究。它通过使用GPU来加速Python和NumPy代码。
JAX在机器学习领域崭露头角,其野心是使机器学习变得简单而高效。虽然,JAX仍然是谷歌和Deepmind的研究项目,还不是谷歌的官方产品,但已经被内部广泛使用,并被外部的ML研究人员采用。我们想提供一个关于JAX的介绍,如何安装JAX,以及它的优势和能力。
什么是机器学习的JAX?
JAX是一个为高性能数值计算设计的Python库,特别是机器学习研究。它的数值函数的API是基于NumPy的,这是一个用于科学计算的函数集合。JAX专注于通过使用XLA在GPU上编译NumPy函数来加速机器学习过程,并使用autograd来区分Python和NumPy函数,以及基于梯度的优化。JAX能够通过循环、分支、递归和闭合进行分化,并利用GPU加速轻松地获取导数的导数。JAX还支持反向传播和正向模式的微分。
当使用GPU运行你的代码时,JAX提供了卓越的性能,还有一个及时编译(JIT)选项,可以轻松加快大型项目的速度,我们将在本文后面深入探讨这个问题。
把JAX看作是一个Python库,它通过函数转换来修改NumPy和Python代码,以实现加速的机器学习。一般来说,只要你打算用GPU进行训练,计算梯度(autograd),或者使用JIT代码编译,都应该使用JAX。
为什么使用JAX?
除了与普通的CPU一起工作外,JAX的主要功能是能够与不同的处理单元(如GPU)一起完全发挥作用。这使得JAX与类似的软件包相比具有很大的优势,因为在涉及到图像和矢量处理时,使用GPU并行化可以使性能比CPU更快。
这一点极为重要,因为在使用NumPy库时,用户可以建立特殊大小的矩阵,使GPU在处理这类数据格式时更有时间效率。
这个时间差使得JAX库的速度和性能通过几个关键的实现超过了NumPy本身100倍以上。
矢量化--将多个数据作为单一指令处理,为线性代数计算和机器学习提供了巨大的速度。
代码并行化--在单个处理器上运行的串行代码,并将其分发出去的过程。这里首选GPU,因为它们有许多专门用于计算的处理器。
自动微分--非常简单和直接的微分,可以多次串联,轻松地评估高阶导数。
如何安装JAX
要安装只有CPU版本的JAX,这对于在笔记本电脑上进行本地开发可能是有用的,你可以运行
在Linux上,通常需要先将pip更新到支持manylinux2014轮的版本。
pip安装GPU (CUDA)
要安装支持CPU和NVIDIA GPU的JAX,你必须先安装CUDA和CuDNN,如果它们还没有被安装。与许多其他流行的深度学习系统不同,JAX并没有将CUDA或CuDNN作为pip包的一部分来捆绑。
JAX只为Linux提供预建的兼容CUDA的*,包括CUDA 11.1或更新版本,以及CuDNN 8.0.5或更新版本。其他操作系统、CUDA和CuDNN的组合也是可能的,但需要从源代码中构建。
需要CUDA 11.1或更新版本
如果你从源码构建,你可能会使用更早的CUDA版本,但是所有11.1以上的CUDA版本都有已知的错误,所以我们不会为旧的CUDA版本提供预构建的二进制文件。
预置*支持的cuDNN版本是:
cuDNN 8.2或更新版本。如果你的cuDNN安装得足够新,我们建议使用cuDNN 8.2*,因为它支持额外的功能。
cuDNN 8.0.5或更新版本。
您必须使用至少与您的CUDA工具箱对应的驱动版本一样新的NVIDIA驱动版本。例如,如果你安装了CUDA 11.4 update 4,如果在Linux上,你必须使用NVIDIA驱动470.82.01或更新版本。这是一个严格的要求,它的存在是因为JAX依赖于JIT-compiling代码;旧的驱动程序可能会导致失败。
如果你需要使用较新的CUDA工具包和较旧的驱动程序,例如在一个不能轻易更新NVIDIA驱动程序的集群上,你也许可以使用NVIDIA为此提供的CUDA向前兼容包。
jaxlib的版本必须与你要使用的现有CUDA安装的版本相对应。你可以为jaxlib明确指定一个特定的CUDA和CuDNN版本。
你可以用命令找到你的CUDA版本:
比较JAX和NumPy
由于JAX是一个增强的NumPy,它们的语法非常相似,使用户有能力在NumPy或JAX不执行的项目中交替使用这两种方法。这通常是在较小的项目中,加速的数量在节省的时间上是可以忽略不计的。然而,随着模型越来越大,你越应该考虑JAX。
使用JAX与NumPy进行两个矩阵的乘法运算
为了清楚地说明这两个库的速度差异,我们将使用这两个库将两个矩阵相乘,然后检查仅CPU和GPU的性能差异。我们还将检查由JIT编译器引起的性能提升。
为了继续学习本教程,请安装并导入JAX和NumPy库(来自前一步)。你可以在Kaggle或Google Colab等网站上测试你的代码。与任何库一样,你应该在代码的开头写上以下几行来导入JAX。
你也可以用类似的方式导入NumPy库:
接下来,我们将使用CPU和GPU比较JAX和Numpy的性能,在Python中把两个矩阵相乘。对于这些基准测试,越低越好。
CPU上的NumPy
首先,我们将使用NumPy创建一个5,000乘以5,000的矩阵,并测试其速度方面的性能。
每循环785毫秒
在NumPy上运行的代码的单次循环,每次循环的时间约为750毫秒。
CPU上的JAX
现在让我们运行同样的代码,但这次是使用JAX库。
每个循环1.43秒
正如你所看到的,比较JAX和NumPy的纯CPU性能表明,NumPy是更快的选择。虽然JAX在普通的CPU上可能无法提供最好的性能,但它在GPU上确实提供了更好的性能。
使用GPU的JAX
现在,让我们尝试创建同样的5,000乘5,000的矩阵,这次使用JAX与GPU而不是普通CPU。
每循环80.6毫秒
正如清楚显示的那样,当在GPU上而不是CPU上运行JAX时,我们实现了更好的时间,每循环约80ms(大约15倍的性能)。当使用更大的矩阵或时间尺度时,这将更容易看到。
及时编译(JIT)
使用jit命令,我们的代码将使用特定的XLA编译器进行编译,使我们的函数能够有效执行。
XLA是加速线性代数的简称,被JAX和Tensorflow等库用来在GPU上编译和运行代码,效率更高。因此,总结起来,XLA是一个特定的线性代数编译器,能够以更高的速度编译代码。
我们将使用selu_np函数测试我们的代码,该函数代表缩放指数线性单元,并检查NumPy在普通CPU上的不同时间表现,以及在GPU上用JIT运行JAX。
CPU上的NumPy
首先,我们将使用NumPy库创建一个大小为1,000,000的向量。
每循环8.3毫秒
GPU上的JAX与JIT
现在我们将在GPU上使用JAX和JIT来测试我们的代码。
每个循环153微秒(每个循环0.153毫秒)
最后,当使用JIT编译器和GPU时,我们得到了比使用普通GPU更好的性能。正如你可以清楚地看到,差异是非常明显的,从NumPy到使用JIT的JAX,速度提高了近5000%,或者说是50倍!
把JAX看成是对NumPy的修改,以实现用GPU加速机器学习。由于NumPy只能在CPU上编译,如果你选择在GPU上执行代码,JAX就比NumPy快。作为一般规则,只要计划在GPU上使用NumPy或使用JIT代码编译,就应该使用JAX。
JAX的局限性:纯函数
JAX 转换和复杂化是为功能纯正的 Python 函数设计的。纯函数不能通过访问外部变量来改变程序的状态,也不能对诸如 print() 这样的输入/输出流函数产生副作用。
连续的运行会导致这些副作用不能按预期执行。如果你不小心,未被追踪的副作用可能会使你的预期计算的准确性受到影响。
使用谷歌的JAX
在这篇文章中,我们解释了JAX的功能以及它给NumPy带来的优势。我们介绍了如何安装JAX库以及它对机器学习的优势。
然后我们继续导入JAX和NumPy。此外,我们将JAX与NumPy(这是最著名的竞争库)进行了比较,并揭示了这两者之间的时间和性能差异,使用普通的CPU和GPU以及一些JIT测试,看到了速度的大幅提高。
如果你是一个高级机器/深度学习从业者,那么在你的武器库中添加一个像JAX这样的库,它的(GPU/TPU)加速器和它的高效JIT编译器肯定会让你的生活变得更加轻松。