Lecture 25
import numpy as np
import matplotlib.pyplot as plt

x_np = np.linspace(0, 10, 101)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np)
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
JAX provides a NumPy-inspired interface for convenience (jax.numpy), which can often be used as a drop-in replacement
All JAX operations are implemented in terms of operations in XLA (the Accelerated Linear Algebra compiler)
Supports sequential execution or JIT compilation
Provides an updated autograd which can be used with native Python and NumPy functions
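For reference, the NumPy example at the top of these notes has a direct JAX analogue; the arrays printed below appear to be x_np, its jax.numpy counterpart, and the corresponding y values (a minimal sketch - the names x_jax and y_jax are assumptions):

import jax.numpy as jnp

x_jax = jnp.linspace(0, 10, 101)              # mirrors np.linspace()
y_jax = 2 * jnp.sin(x_jax) * jnp.cos(x_jax)   # same expression as the NumPy version

x_np    # NumPy prints array([...]) with default dtype float64
x_jax   # JAX prints Array([...]) with default dtype float32
y_jax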
array([ 0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
0.8, 0.9, 1. , 1.1, 1.2, 1.3, 1.4, 1.5,
1.6, 1.7, 1.8, 1.9, 2. , 2.1, 2.2, 2.3,
2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3. , 3.1,
3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9,
4. , 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7,
4.8, 4.9, 5. , 5.1, 5.2, 5.3, 5.4, 5.5,
5.6, 5.7, 5.8, 5.9, 6. , 6.1, 6.2, 6.3,
6.4, 6.5, 6.6, 6.7, 6.8, 6.9, 7. , 7.1,
7.2, 7.3, 7.4, 7.5, 7.6, 7.7, 7.8, 7.9,
8. , 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7,
8.8, 8.9, 9. , 9.1, 9.2, 9.3, 9.4, 9.5,
9.6, 9.7, 9.8, 9.9, 10. ])
Array([ 0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
0.8, 0.9, 1. , 1.1, 1.2, 1.3, 1.4, 1.5,
1.6, 1.7, 1.8, 1.9, 2. , 2.1, 2.2, 2.3,
2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3. , 3.1,
3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9,
4. , 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7,
4.8, 4.9, 5. , 5.1, 5.2, 5.3, 5.4, 5.5,
5.6, 5.7, 5.8, 5.9, 6. , 6.1, 6.2, 6.3,
6.4, 6.5, 6.6, 6.7, 6.8, 6.9, 7. , 7.1,
7.2, 7.3, 7.4, 7.5, 7.6, 7.7, 7.8, 7.9,
8. , 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7,
8.8, 8.9, 9. , 9.1, 9.2, 9.3, 9.4, 9.5,
9.6, 9.7, 9.8, 9.9, 10. ], dtype=float32)
Array([ 0. , 0.19867, 0.38942, 0.56464, 0.71736,
0.84147, 0.93204, 0.98545, 0.99957, 0.97385,
0.9093 , 0.8085 , 0.67546, 0.5155 , 0.33499,
0.14112, -0.05837, -0.25554, -0.44252, -0.61186,
-0.7568 , -0.87158, -0.9516 , -0.99369, -0.99616,
-0.95892, -0.88345, -0.77276, -0.63127, -0.4646 ,
-0.27942, -0.08309, 0.11655, 0.31154, 0.49411,
0.65699, 0.79367, 0.89871, 0.96792, 0.99854,
0.98936, 0.94073, 0.8546 , 0.7344 , 0.58492,
0.41212, 0.22289, 0.02478, -0.17433, -0.36648,
-0.54402, -0.69987, -0.82783, -0.92278, -0.98094,
-0.99999, -0.97918, -0.91933, -0.82283, -0.69353,
-0.53657, -0.35823, -0.1656 , 0.03362, 0.23151,
0.42017, 0.59207, 0.74038, 0.85916, 0.9437 ,
0.99061, 0.99803, 0.96566, 0.89479, 0.78825,
0.65029, 0.4864 , 0.30312, 0.10775, -0.09191,
-0.2879 , -0.47242, -0.63811, -0.77835, -0.88757,
-0.9614 , -0.9969 , -0.99266, -0.94885, -0.8672 ,
-0.75099, -0.60483, -0.43457, -0.24698, -0.04954,
0.14988, 0.34331, 0.52307, 0.68196, 0.81367,
0.91295], dtype=float32)
Pseudorandom number generation in JAX is a bit different from NumPy - the latter depends on a global state that is updated each time a random function is called.
NumPy’s PRNG guarantees something called sequential equivalence, which means that sampling N numbers one at a time gives the same result as sampling N numbers at once (e.g. as a vector of length N).
Sequential equivalence can be problematic in light of parallelization; consider code along the lines of the following sketch (the original snippet is not included in this excerpt, so the function bodies are assumptions):
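import numpy as np

np.random.seed(0)

def bar():
    return np.random.uniform()

def baz():
    return np.random.uniform()

def foo():
    # the result depends on which of bar() and baz() reads from the
    # global NumPy PRNG state first
    return bar() + 2 * baz()

print(foo())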
How do we guarantee that we get consistent results if we don’t know the order in which bar() and baz() will run?
JAX makes use of ‘random keys’, which are just a fancier version of random seeds - all of JAX’s random functions require that a key be passed in.
Since a key is essentially a seed, we do not want to reuse keys (unless we want identical output). Therefore, to generate multiple different pseudorandom numbers we can split a key to deterministically generate two (or more) new keys.
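A minimal sketch of working with keys (the seed and shapes here are arbitrary):

import jax

key = jax.random.PRNGKey(1234)               # a key plays the role of a seed

# reusing `key` directly would reproduce the same draws, so split it instead
key, subkey = jax.random.split(key)
x1 = jax.random.normal(subkey, shape=(3,))

key, subkey = jax.random.split(key)
x2 = jax.random.normal(subkey, shape=(3,))   # different draws than x1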
4.38 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
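The TracerArrayConversionError below refers to a SELU_np function that is not included in this excerpt; a plausible NumPy-based sketch (the constants and signature are assumptions) that fails in the same way once wrapped in jax.jit:

import numpy as np
import jax

def SELU_np(x, alpha=1.67, lmbda=1.05):
    # NumPy functions (np.exp, np.where) try to convert their inputs to numpy.ndarray,
    # which calls __array__() on the JAX tracer during jit tracing and raises the error below
    return lmbda * np.where(x > 0, x, alpha * (np.exp(x) - 1))

SELU_jit = jax.jit(SELU_np)
# SELU_jit(x) raises TracerArrayConversionError, since tracing always fails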
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[1000000])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function SELU_np at /var/folders/ds/8sqz2v4d355btthn6r88kdc00000gn/T/ipykernel_75333/762427251.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
When it works the jit tool is fantastic, but it does have a number of limitations:
Must use pure functions (no side effects)
Must primarily use JAX functions, e.g. jnp.minimum() not np.minimum() or min() (see the sketch after this list)
Must generally avoid conditionals / control flow
Issues around concrete values when tracing (static values)
Check performance - there are not always gains, and there is the initial cost of compilation
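A sketch of a jit-friendly rewrite of the SELU example above, using only jax.numpy operations (the constants are again assumptions):

import jax
import jax.numpy as jnp

@jax.jit
def SELU_jax(x, alpha=1.67, lmbda=1.05):
    # jnp.where() and jnp.exp() understand tracers, so this compiles cleanly
    return lmbda * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

x = jnp.linspace(-3, 3, 1_000_000)
SELU_jax(x).block_until_ready()   # block_until_ready() matters when timing, due to async dispatch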
Like with torch, the grad() function takes a numerical function returning a scalar and returns a function for calculating the gradient of that function.
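A small illustration (the function f here is hypothetical):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**2)                 # must return a scalar

grad_f = jax.grad(f)                     # grad_f is itself a function
grad_f(jnp.array([1.0, 2.0, 3.0]))       # Array([2., 4., 6.], dtype=float32)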
vmap()
I would like to plot h() and jax.grad(h)() - let’s see what happens.
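Since h() is not defined in this excerpt, the sketch below uses a stand-in scalar function; the point is that jax.grad(h) only accepts scalar inputs, so it must be vectorized with jax.vmap() before it can be evaluated over a grid of x values for plotting:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

def h(x):                                # hypothetical stand-in for the lecture's h()
    return jnp.sin(x) * jnp.exp(-x / 5)

x = jnp.linspace(0, 10, 101)

dh = jax.vmap(jax.grad(h))               # grad(h) handles one scalar; vmap maps it over x

plt.plot(x, h(x), label="h(x)")
plt.plot(x, dh(x), label="h'(x)")
plt.legend()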
y x1 x2 x3 x4 x5
0 -0.151710 0.353658 1.633932 0.553257 1.415731 A
1 3.579895 1.311354 1.457500 0.072879 0.330330 B
2 0.768329 -0.744034 0.710362 -0.246941 0.008825 B
3 7.788646 0.806624 -0.228695 0.408348 -2.481624 B
4 1.394327 0.837430 -1.091535 -0.860979 -0.810492 A
.. ... ... ... ... ... ..
495 -0.204932 -0.385814 -0.130371 -0.046242 0.004914 A
496 0.541988 0.845885 0.045291 0.171596 0.332869 A
497 -1.402627 -1.071672 -1.716487 -0.319496 -1.163740 C
498 -0.043645 1.744800 -0.010161 0.422594 0.772606 A
499 -1.550276 0.910775 -1.675396 1.921238 -0.232189 B
[500 rows x 6 columns]
def model(b, X=X):
    return X @ b

def reg_loss(b, λ=0., X=X, y=y, model=model):
    return jnp.mean((y - model(b, X).squeeze())**2)

def ridge_loss(b, λ=0., X=X, y=y, model=model):
    return jnp.mean((y - model(b, X).squeeze())**2) + λ * jnp.sum(b**2)

def lasso_loss(b, λ=0., X=X, y=y, model=model):
    return jnp.mean((y - model(b, X).squeeze())**2) + λ * jnp.sum(jnp.abs(b))
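The fit() function that is jit-compiled below is not shown in this excerpt; a minimal gradient-descent sketch consistent with the error message (the learning rate, iteration count, and the float() bookkeeping are assumptions):

def fit(b, loss, λ=0., n=250, lr=0.1, X=X, y=y, model=model):
    grad_fn = jax.grad(loss)
    losses = []
    for i in range(n):
        b = b - lr * grad_fn(b, λ, X=X, y=y, model=model)
        # float() forces a concrete value, which is what fails once fit() is traced by jit
        losses.append(float(loss(b, λ, X=X, y=y, model=model)))
    return b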
fit_jit = jax.jit(fit, static_argnames=["loss","λ","n","X","y","model"])
b_hat = fit_jit(b, reg_loss)
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
The error occurred while tracing the function fit at /var/folders/ds/8sqz2v4d355btthn6r88kdc00000gn/T/ipykernel_75333/3842046654.py:1 for jit. This value became a tracer due to JAX operations on these lines:
operation a:f32[500,1] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b c
from line /var/folders/ds/8sqz2v4d355btthn6r88kdc00000gn/T/ipykernel_75333/530544497.py:2 (model)
operation a:f32[500] = sub b c
from line /var/folders/ds/8sqz2v4d355btthn6r88kdc00000gn/T/ipykernel_75333/530544497.py:5 (reg_loss)
operation a:f32[] = div b c
from line /var/folders/ds/8sqz2v4d355btthn6r88kdc00000gn/T/ipykernel_75333/530544497.py:5 (reg_loss)
operation a:f32[] = div b c
from line /var/folders/ds/8sqz2v4d355btthn6r88kdc00000gn/T/ipykernel_75333/530544497.py:5 (reg_loss)
operation a:f32[1,8] = dot_general[dimension_numbers=(([0], [0]), ([], []))] b c
from line /var/folders/ds/8sqz2v4d355btthn6r88kdc00000gn/T/ipykernel_75333/530544497.py:2 (model)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
442 ms ± 2.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
307 µs ± 2.44 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.
Why do we need them?
In machine learning, some places where you commonly find pytrees are:
Model parameters
Dataset entries
This helps us avoid functions with large argument lists and makes it possible to vectorize / map more operations.
JAX provides a number of built-in tools for working with / iterating over pytrees, tree_map() being the most commonly used. tree_map() iterates over all of the leaf elements and applies the desired function to each while maintaining the structure of the pytree (similar to rapply() in R).
example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

jax.tree_map(type, example_trees)
[[int, str, object],
(int, (int, int), ()),
[int, {'k1': int, 'k2': (int, int)}, int],
{'a': int, 'b': (int, int)},
jaxlib.xla_extension.ArrayImpl]
def init_params(layer_widths, key):
    params = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        key, new_key = jax.random.split(key)
        params.append(
            dict(
                W = jax.random.normal(new_key, shape=(n_in, n_out)) * np.sqrt(2/n_in),
                b = jnp.ones(shape=(n_out,))
            )
        )
    return params
key = jax.random.PRNGKey(1234)
params = init_params([1, 128, 128, 1], key)
from functools import partial

class model:
    def forward(self, params, x):
        *hidden, last = params
        for layer in hidden:
            x = x @ layer['W'] + layer['b']
            x = jax.nn.relu(x)
        return x @ last['W'] + last['b']

    def loss_fn(self, params, x, y):
        return jnp.mean((self.forward(params, x) - y) ** 2)

    @partial(jax.jit, static_argnames=['self', 'lr'])
    def step(self, params, x, y, lr=0.0001):
        grads = jax.grad(self.loss_fn)(params, x, y)  # since `params` is a pytree, `grads` will be one too
        return jax.tree_map(
            lambda p, g: p - lr * g, params, grads
        )

    def fit(self, params, x, y, n=1000):
        for i in range(n):
            params = self.step(params, x, y)
        return params
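A hedged usage sketch (the training data here is synthetic, standing in for whatever x and y the lecture actually uses):

x = jnp.linspace(-2, 2, 250).reshape(-1, 1)
y = x**3 + 0.1 * jax.random.normal(jax.random.PRNGKey(0), shape=x.shape)

nn = model()
params = nn.fit(params, x, y, n=1000)    # params from init_params([1, 128, 128, 1], key) above
preds = nn.forward(params, x)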
There are a number of other libraries built on top of JAX that provide higher-level interfaces for common tasks:
Neural networks (torch-like interfaces)
Bayesian models
Other
Optax - gradient processing and optimization library (DeepMind)
Awesome-JAX - collection of JAX related links and resources
Sta 663 - Spring 2023