Google JAX
| | |
|---|---|
| Developer(s) | Google |
| Preview release | v0.3.13 / 16 May 2022 |
| Repository | github.com/google/jax |
| Written in | Python, C++ |
| Operating system | Linux, macOS, Windows |
| Platform | Python, NumPy |
| Size | 9.0 MB |
| Type | Machine learning |
| License | Apache 2.0 |
Google JAX is a machine learning framework for transforming numerical functions.[1][2][3] It is described as bringing together a modified version of Autograd (automatic differentiation, which obtains the gradient function of a numerical function) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[4][5] The primary functions of JAX, which can be composed with one another, are:[1]
- grad: automatic differentiation
- jit: compilation
- vmap: auto-vectorization
- pmap: SPMD programming
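These transformations compose as ordinary function transformations. The sketch below is illustrative (the function f is not taken from the JAX documentation): it differentiates a scalar function, vectorizes the derivative over a batch of inputs, and compiles the result.

```python
# hedged sketch: grad, jit and vmap compose as ordinary function transformations
from jax import grad, jit, vmap
import jax.numpy as jnp

# illustrative scalar function
def f(x):
    return jnp.sin(x) * x

# derivative of f, vectorized over a batch of inputs, then compiled
df_batched = jit(vmap(grad(f)))

xs = jnp.linspace(0.0, 1.0, 5)
print(df_batched(xs))  # element-wise values of f'(x) = x*cos(x) + sin(x)
```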
grad
The code below demonstrates the grad function's automatic differentiation.
```python
# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
```
The final line should output:
0.19661194
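As a sanity check (a sketch, not part of the original example), the same value can be obtained from the closed-form derivative of the logistic function, which equals logistic(x) · (1 − logistic(x)):

```python
# hedged sketch: check grad against the analytic derivative of the logistic function
import jax.numpy as jnp

def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

x = 1.0
# for the logistic function, the derivative is logistic(x) * (1 - logistic(x))
analytic = logistic(x) * (1 - logistic(x))
print(analytic)  # ~0.19661194, matching the value returned by grad_logistic(1.0)
```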
jit
The code below demonstrates the jit function's optimization through fusion.
```python
# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
```
After an initial call that triggers compilation, the computation time for jit_cube(x) should be noticeably shorter than that for cube(x). Increasing the dimensions passed to jnp.ones will increase the difference.
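One way to measure the difference is sketched below; it assumes cube, jit_cube and x are defined as above and uses block_until_ready so that JAX's asynchronous dispatch does not distort the timing:

```python
# hedged sketch: timing cube vs. jit_cube (assumes cube, jit_cube and x defined as above)
import timeit

# jit_cube was already called once above, so compilation is not measured here;
# block_until_ready() forces the asynchronous computation to finish before the timer stops
t_plain = timeit.timeit(lambda: cube(x).block_until_ready(), number=10)
t_jit = timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=10)
print(f"cube: {t_plain:.3f} s, jit_cube: {t_jit:.3f} s")
```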
vmap
The code below demonstrates the vmap function's vectorization.
```python
# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# method taken from a larger class; self._net_grads, self._net_params and
# self._flatten_batch are assumed to be defined elsewhere on that class
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = jnp.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
```
Vectorized addition is a simple illustration of this idea: a function written to add two single values is applied across whole arrays of inputs at once.
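A minimal, self-contained sketch of vectorized addition with vmap (the function and array names here are illustrative):

```python
# hedged sketch: vectorized addition with vmap
from jax import vmap
import jax.numpy as jnp

def add(a, b):
    return a + b

xs = jnp.arange(4.0)   # [0., 1., 2., 3.]
ys = jnp.ones(4)       # [1., 1., 1., 1.]

# vmap maps add over the leading axis of both inputs without an explicit Python loop
batched_add = vmap(add)
print(batched_add(xs, ys))  # [1. 2. 3. 4.]
```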
pmap
The code below demonstrates the pmap function's parallelization for matrix multiplication.
```python
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
```
The final line should print the values:
[1.1566595 1.1805978]
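Because pmap maps the computation over one device per element of the leading axis, the example above requires at least two devices to run as written. A quick check of the available device count, as a sketch:

```python
# hedged sketch: check how many devices are available for pmap to map over
import jax

n_devices = jax.device_count()
print(n_devices)  # the mapped axis passed to a pmapped function must not exceed this
```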
Libraries using JAX
Several Python libraries use JAX as a backend, including:
- Flax, a high level neural network library initially developed by Google Brain.[6]
- Equinox, a library that revolves around the idea of representing parameterised functions (including neural networks) as PyTrees. It was created by Patrick Kidger.[7]
- Diffrax, a library for the numerical solution of differential equations, such as ordinary differential equations and stochastic differential equations.[8]
- Optax, a library for gradient processing and optimisation developed by DeepMind (a usage sketch appears after this list).[9]
- Lineax, a library for numerically solving linear systems and linear least squares.[10]
- RLax, a library for developing reinforcement learning agents developed by DeepMind.[11]
- jraph, a library for graph neural networks, developed by DeepMind.[12]
- jaxtyping, a library for adding type annotations[13] for the shape and data type ("dtype") of arrays or tensors.[14]
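As one illustration of how such libraries plug into JAX, the following sketch performs a single Optax gradient step (the loss function, parameter values and learning rate are illustrative):

```python
# hedged sketch: a single optimisation step with Optax on a JAX-computed gradient
import jax
import jax.numpy as jnp
import optax

loss = lambda p: jnp.sum(p ** 2)   # illustrative loss function
params = jnp.array([1.0, -2.0, 3.0])

optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)

grads = jax.grad(loss)(params)                  # gradient computed by JAX
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
print(params)                                   # parameters after one gradient step
```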
Some R libraries use JAX as a backend as well, including:
- fastrerandomize, a library that uses JAX's linear-algebra-optimized compiler to speed up the selection of balanced randomizations in a design-of-experiments procedure known as rerandomization.[15]
See also
- NumPy
- TensorFlow
- PyTorch
- CUDA
- Automatic differentiation
- Just-in-time compilation
- Vectorization
- Automatic parallelization
- Accelerated Linear Algebra
External links
- Documentation: jax.readthedocs.io
- Colab (Jupyter/IPython) Quickstart Guide: colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb
- TensorFlow's XLA (Accelerated Linear Algebra): www.tensorflow.org/xla
- Original paper: mlsys.org/Conferences/doc/2018/146.pdf
References
- ↑ 1.0 1.1 Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam et al. (2022-06-18), "JAX: Autograd and XLA", Astrophysics Source Code Library (Google), Bibcode: 2021ascl.soft11002B, https://github.com/google/jax, retrieved 2022-06-18
- ↑ Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing". MLsys: 1–3. https://mlsys.org/Conferences/doc/2018/146.pdf.
- ↑ "Using JAX to accelerate our research" (in en). https://www.deepmind.com/blog/using-jax-to-accelerate-our-research.
- ↑ Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta" (in en-US). https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6.
- ↑ "Why is Google's JAX so popular?" (in en-US). 2022-04-25. https://analyticsindiamag.com/why-is-googles-jax-so-popular/.
- ↑ Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, https://github.com/google/flax, retrieved 2022-07-29
- ↑ Kidger, Patrick (2022-07-29), Equinox, https://github.com/patrick-kidger/equinox, retrieved 2022-07-29
- ↑ Kidger, Patrick (2023-08-05), Diffrax, https://github.com/patrick-kidger/diffrax, retrieved 2023-08-08
- ↑ Optax, DeepMind, 2022-07-28, https://github.com/deepmind/optax, retrieved 2022-07-29
- ↑ Lineax, Google, 2023-08-08, https://github.com/google/lineax, retrieved 2023-08-08
- ↑ RLax, DeepMind, 2022-07-29, https://github.com/deepmind/rlax, retrieved 2022-07-29
- ↑ Jraph - A library for graph neural networks in jax., DeepMind, 2023-08-08, https://github.com/deepmind/jraph, retrieved 2023-08-08
- ↑ "typing — Support for type hints". https://docs.python.org/3/library/typing.html.
- ↑ jaxtyping, Google, 2023-08-08, https://github.com/google/jaxtyping, retrieved 2023-08-08
- ↑ Jerzak, Connor (2023-10-01), fastrerandomize, https://github.com/cjerzak/fastrerandomize-software, retrieved 2023-10-03