First Steps in JAX
import jax
import jax.numpy as jnp  # NumPy-compatible array API provided by JAX
import numpy as np
jnp.arange(10)  # create a JAX array holding the integers 0 through 9
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
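The cell above relies on `jax.numpy` following the NumPy API. The following is a minimal sketch of that correspondence, assuming the same CPU-only setup shown in the log message. The variable names (`x_jax`, `x_np`, `total`) are illustrative only. The sketch checks that `jnp.arange` produces the same values as `np.arange` and that a familiar reduction works on the JAX array.

```python
import jax.numpy as jnp
import numpy as np

# jnp.arange takes the same arguments as np.arange and yields the same values.
x_jax = jnp.arange(10)
x_np = np.arange(10)

# np.asarray copies the JAX array into a host NumPy array for comparison.
# Default integer dtypes may differ (JAX uses int32 unless 64-bit mode is
# enabled), so compare values rather than dtypes.
assert np.array_equal(np.asarray(x_jax), x_np)

# NumPy-style reductions work on JAX arrays as well.
total = jnp.sum(x_jax)
print(int(total))  # 45
```

Comparing values instead of dtypes is deliberate: as the `dtype=int32` in the output above shows, JAX defaults to 32-bit integers, whereas NumPy typically uses 64-bit integers on most platforms.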