mpi4jax:MPI向Jax致意并加快了速度

时间:2024-04-06 12:49:23
【文件属性】:

文件名称:mpi4jax:MPI向Jax致意并加快了速度

文件大小:4.54MB

文件格式:ZIP

更新时间:2024-04-06 12:49:23

mpi jit jax xla Python

mpi4jax mpi4jax支持阵列的零复制,多主机通信,甚至可以通过固定代码和GPU内存进行通信。 但为什么? JAX框架,但是其仍然受到限制。 使用mpi4jax ,您可以将基于JAX的模拟扩展到整个CPU和GPU集群(无需离开jax.jit )。 本着差异化编程的精神, mpi4jax还支持通过一些MPI操作进行差异化。 快速安装 mpi4jax可通过pip和conda : $ pip install mpi4jax # Pip $ conda install -c conda-forge mpi4jax # conda 我们的文档包括一些更高级的安装示例。 用法示例 from mpi4py import MPI import jax import jax . numpy as jnp import mpi4jax comm


【文件预览】:
mpi4jax-master
----README.rst(3KB)
----.readthedocs.yml(193B)
----pyproject.toml(601B)
----conf()
--------travis-install-mpi.sh(4KB)
--------ompi_rootenv.patch(1KB)
----docs()
--------shallow-water.rst(4KB)
--------api.rst(601B)
--------environment.yml(193B)
--------conf.py(2KB)
--------usage.rst(3KB)
--------make.bat(795B)
--------installation.rst(3KB)
--------sharp-bits.rst(4KB)
--------shallow-water-source.rst(241B)
--------Makefile(634B)
--------index.rst(235B)
--------_static()
----.github()
--------workflows()
----tests()
--------test_decorators.py(1KB)
--------test_examples.py(494B)
--------test_deprecations.py(2KB)
--------test_invalid_jaxlib.py(361B)
--------conftest.py(376B)
--------test_validation.py(2KB)
--------test_flush.py(142B)
--------collective_ops()
----mpi4jax()
--------_src()
--------__init__.py(735B)
--------_deprecations.py(1KB)
----.git-archival.txt(23B)
----examples()
--------shallow_water.py(16KB)
----setup.py(6KB)
----LICENSE.md(1KB)
----.gitignore(2KB)
----.flake8(163B)
----.gitattributes(46B)
----.pre-commit-config.yaml(207B)

网友评论