Backends overview
brainunit pairs a physical Unit with an array mantissa. The mantissa can
live on any one of several array backends, and every unit-aware operation
(brainunit.math, brainunit.linalg, brainunit.fft, plain arithmetic)
dispatches to the matching backend’s array library. You can stay inside one
backend end-to-end, mix them, or switch by calling a single conversion
method.
This page describes the architecture, the supported backends, how selection works, what each backend can and cannot do, and how to install optional backend dependencies. Per-backend notebooks follow it: see JAX, NumPy, CuPy, PyTorch, Dask, and ndonnx.
Supported backends
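| Backend | Mantissa type | Optional install | Typical use case |
|---|---|---|---|
| jax | jax.Array | required (core) | autograd, JIT, vmap, accelerators (default) |
| numpy | numpy.ndarray | required (core) | scipy / pandas / sklearn interop, CPU |
| cupy | cupy.ndarray | pip install "brainunit[cupy]" | NVIDIA GPU arrays, drop-in NumPy replacement |
| torch | torch.Tensor | pip install "brainunit[torch]" | PyTorch models, CUDA/MPS tensors |
| dask | dask.array.Array | pip install "brainunit[dask]" | out-of-core / parallel arrays, lazy compute |
| ndonnx | ndonnx.Array | pip install "brainunit[ndonnx]" | symbolic graph building, ONNX export |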
jax and numpy are always available because both are required core
dependencies. The other four are opt-in: if you do not install the extra,
brainunit still works — it just refuses to dispatch onto that backend and
raises brainunit.BackendError with the matching pip install hint when you
ask for one explicitly.
Internally, the numpy, cupy, torch, and dask namespaces are sourced
from
array_api_compat so they
all expose the same array-API-standard surface. jax.numpy (JAX ≥ 0.9) and
ndonnx are array-API compatible on their own and are used unwrapped.
How backend selection works
For every operation, brainunit asks: “which array library should compute the result?” The rule:
1. Inspect the input mantissas. If exactly one backend kind is present, use it.
2. If inputs mix backends, or there are no array inputs, consult the thread-local default backend set by set_default_backend(...) / using_backend(...).
3. If no default is set, fall back to jax (the historical default).
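The three-step rule can be modeled as a small pure function. This is a hypothetical sketch of the rule as stated above, not brainunit's internal code. resolve_backend and its arguments are illustrative names, and input backends are passed as plain strings.

```python
from typing import Iterable, Optional


def resolve_backend(input_kinds: Iterable[str], default: Optional[str] = None) -> str:
    """Hypothetical model of the selection rule (not brainunit's implementation).

    input_kinds: backend names of the array inputs, e.g. ['numpy', 'jax'].
    default: the value set via set_default_backend / using_backend, or None.
    """
    kinds = set(input_kinds)
    if len(kinds) == 1:
        # Step 1: exactly one backend kind among the inputs, so it wins.
        return next(iter(kinds))
    if default is not None:
        # Step 2: mixed inputs, or no array inputs at all: use the default.
        return default
    # Step 3: no default set, so fall back to jax.
    return 'jax'


print(resolve_backend(['numpy', 'numpy']))  # numpy
print(resolve_backend(['numpy', 'jax']))    # jax
print(resolve_backend([]))                  # jax
```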
```python
import numpy as np
import jax.numpy as jnp
import brainunit as u

q_np = u.Quantity(np.array([1.0]), unit=u.meter)
q_jax = u.Quantity(jnp.array([2.0]), unit=u.meter)
print('q_np.backend =', q_np.backend)
print('q_jax.backend =', q_jax.backend)
print('(q_np + q_np).bk =', (q_np + q_np).backend)  # single -> wins
print('(q_np + q_jax).bk =', (q_np + q_jax).backend)  # mixed -> default
```

```
q_np.backend = numpy
q_jax.backend = jax
(q_np + q_np).bk = numpy
(q_np + q_jax).bk = jax
```
Override the tiebreaker with the context manager:
```python
with u.using_backend('numpy'):
    print('inside using_backend:', (q_np + q_jax).backend)  # 'numpy'
print('outside:', (q_np + q_jax).backend)  # back to 'jax'
```

```
inside using_backend: jax
outside: jax
```
Or set it for the rest of the program:
```python
u.set_default_backend('numpy')
print(u.get_default_backend())
print((q_np + q_jax).backend)
u.set_default_backend(None)  # restore default
print(u.get_default_backend())
```

```
numpy
jax
None
```
The default is a ContextVar, so it isolates per-thread and per-task; nested
using_backend(...) blocks restore the prior value on exit.
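The save-and-restore behavior follows from how ContextVar tokens work. What follows is a minimal sketch that assumes a module-level ContextVar. _default_backend, get_default, and using are hypothetical names standing in for brainunit's own get_default_backend / using_backend; they are not its actual code.

```python
import contextvars
from contextlib import contextmanager

# Hypothetical stand-in for brainunit's internal default-backend variable.
_default_backend = contextvars.ContextVar('default_backend', default=None)


def get_default():
    return _default_backend.get()


@contextmanager
def using(name):
    token = _default_backend.set(name)  # remember the prior value
    try:
        yield
    finally:
        _default_backend.reset(token)   # restore it, even if the body raised


with using('numpy'):
    with using('torch'):
        print(get_default())  # torch
    print(get_default())      # numpy
print(get_default())          # None


def set_in_other_context():
    _default_backend.set('dask')
    return get_default()


# A set() made inside a copied context (as each asyncio task gets) stays there.
ctx = contextvars.copy_context()
print(ctx.run(set_in_other_context))  # dask
print(get_default())                  # None
```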
Choosing a backend
There is no universally best backend — each one trades capability against ecosystem.
- jax: pick this when you need automatic differentiation, JIT, vmap, or accelerator support out of the box. This is the default and the most fully integrated backend; everything in brainunit.autograd, brainunit.lax, and brainunit.sparse requires it.
- numpy: pick this for interop with the broader scientific Python stack (scipy, pandas, sklearn, matplotlib) where you want eager results with no JAX tracing. Works on CPU only.
- cupy: pick this when you want a near-drop-in NumPy replacement running on an NVIDIA GPU and you don't need autodiff. Requires a CUDA toolkit.
- torch: pick this to embed unit-aware computations inside an existing PyTorch model. PyTorch's own autograd is preserved through brainunit ops, so loss.backward() works on a quantity-derived loss. brainunit.autograd itself is JAX-only; call torch.autograd.grad on the mantissa.
- dask: pick this for arrays that don't fit in memory, or for embarrassingly parallel array work on a cluster. Operations stay lazy until you call .compute().
- ndonnx: pick this when you want to build an ONNX graph symbolically. Operations build the graph rather than executing eagerly. Still maturing: not every brainunit operation has an ndonnx implementation.
Backend capabilities and limitations
Dimensional analysis works on every backend — brainunit tracks units on the
Python Quantity object, independent of the mantissa library. The
limitations below describe what each array backend can and cannot do, not
the unit system.
jax (default)
Full feature set. The only backend that supports:
- brainunit.lax.*: wrappers over jax.lax primitives.
- brainunit.autograd.*: grad, jacobian, hessian.
- brainunit.sparse.*: CSR, CSC, COO sparse matrices.
- jax.jit, jax.vmap, jax.pmap over quantities.
numpy
Eager CPU computation. brainunit.math, brainunit.linalg, and brainunit.fft
all work. The JAX-only subpackages (brainunit.lax, brainunit.autograd,
brainunit.sparse) raise BackendError.
cupy
NVIDIA GPU arrays via CUDA. Same general capability as numpy for
brainunit.math / brainunit.linalg / brainunit.fft, but executed on the GPU.
No autograd, no JIT, no brainunit.lax.
torch
PyTorch tensors. brainunit.math / brainunit.linalg / brainunit.fft route
through array_api_compat.torch. Use torch.autograd.grad on the
mantissa when you need backward passes — brainunit.autograd is JAX-only.
dask
Lazy arrays. Building a quantity, inspecting .shape / .ndim / .dtype,
arithmetic, and most array-API operations stay lazy. Operations that need a
concrete Python value — float(q), int(q), q.tolist(), np.asarray(q),
hash(q), operator.index(q) — raise BackendError; call
q.mantissa.compute() first.
ndonnx
Symbolic / ONNX graph building. Routing is correct for the array-API
operations that ndonnx implements. Operations ndonnx hasn’t implemented yet
surface their own errors unwrapped (brainunit does not catch them). Unit
information lives on the Quantity and is not encoded in the ONNX graph.
Example of a JAX-only operation refusing a NumPy mantissa:
```python
from brainunit import BackendError

q_np = u.Quantity(np.array([1.0, 2.0, 3.0]), unit=u.meter)
try:
    u.lax.slice(q_np, (0,), (1,))
except BackendError as exc:
    print('expected:', exc)

# convert and retry
print(u.lax.slice(q_np.to_jax(), (0,), (1,)))
```

```
expected: brainunit.lax.slice requires the jax backend; got numpy-backed Quantity. Call .to_jax() on the input first.
[1.] m
```
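A guard like this is commonly written as a decorator that checks the input's backend before calling the wrapped function. The sketch below is hypothetical. requires_backend and this BackendError class are illustrative stand-ins, not brainunit's implementation. The backend is read from a .backend attribute, and the message imitates the one printed above.

```python
import functools
from types import SimpleNamespace


class BackendError(RuntimeError):
    """Hypothetical stand-in for brainunit.BackendError."""


def requires_backend(required, get_backend):
    """Hypothetical decorator: reject inputs whose backend is not `required`."""
    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(x, *args, **kwargs):
            actual = get_backend(x)
            if actual != required:
                raise BackendError(
                    f'{fn.__name__} requires the {required} backend; '
                    f'got {actual}-backed input. '
                    f'Call .to_{required}() on the input first.'
                )
            return fn(x, *args, **kwargs)
        return wrapper
    return decorate


@requires_backend('jax', get_backend=lambda q: q.backend)
def first_element(q):
    return q.mantissa[0]


print(first_element(SimpleNamespace(backend='jax', mantissa=[1.0, 2.0])))  # 1.0
try:
    first_element(SimpleNamespace(backend='numpy', mantissa=[1.0, 2.0]))
except BackendError as exc:
    print('expected:', exc)
```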
Optional dependencies and graceful failure
Optional backends are detected lazily. The is_*_array helpers cache
ImportError for the lifetime of the process and never raise:
```python
print('is_jax_array(jnp.zeros(1)) =', u.is_jax_array(jnp.zeros(1)))
print('is_numpy_array(np.zeros(1)) =', u.is_numpy_array(np.zeros(1)))
print('is_cupy_array on a non-cupy obj =', u.is_cupy_array([1, 2, 3]))
print('is_torch_array on a non-torch =', u.is_torch_array([1, 2, 3]))
print('is_dask_array on a non-dask =', u.is_dask_array([1, 2, 3]))
print('is_ndonnx_array on non-ndonnx =', u.is_ndonnx_array([1, 2, 3]))
```

```
is_jax_array(jnp.zeros(1)) = True
is_numpy_array(np.zeros(1)) = True
is_cupy_array on a non-cupy obj = False
is_torch_array on a non-torch = False
is_dask_array on a non-dask = False
is_ndonnx_array on non-ndonnx = False
```
Asking for a backend that isn't installed raises brainunit.BackendError,
not a bare ImportError, and the exception message includes the exact install
command. For a graceful fallback, guard the selection yourself; for example,
probe the imports before choosing a backend:
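```python
import importlib


def pick_backend():
    # Try the optional backends first, then the core ones; the first import
    # that succeeds decides the backend.
    for name, module in [('torch', 'torch'), ('cupy', 'cupy'),
                         ('jax', 'jax'), ('numpy', 'numpy')]:
        try:
            importlib.import_module(module)
            return name
        except ImportError:
            continue  # not installed, try the next candidate
    raise RuntimeError('no array backend available')


print('preferred backend:', pick_backend())
```

```
preferred backend: torch
```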
Conversion between backends
Every Quantity has a per-backend conversion method. If the mantissa is
already on the target backend, the method is a no-op and returns self;
otherwise it returns a new Quantity and leaves the original untouched.
| Method | Notes |
|---|---|
| … | Wraps the mantissa with … |
| … | Materializes ndonnx via … |
| … | … |
| … | … |
| … | Wraps with … |
| … | … |
```python
q_np = u.Quantity(np.array([1.0, 2.0]), unit=u.meter)
q_jax = q_np.to_jax()       # NumPy -> JAX
q_back = q_jax.to_numpy()   # JAX -> NumPy
print(q_np.backend, '->', q_jax.backend, '->', q_back.backend)
```

```
numpy -> jax -> numpy
```
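The copy-versus-no-op contract can be illustrated with a toy class. This is a hypothetical sketch, not brainunit's Quantity. MiniQuantity and to_backend are made-up names, and Python lists and array.array stand in for two backends so the example needs no optional packages.

```python
import array
from dataclasses import dataclass, replace
from typing import Any, Callable


@dataclass(frozen=True)
class MiniQuantity:
    """Hypothetical toy: a mantissa tagged with its unit and backend name."""
    mantissa: Any
    unit: str
    backend: str

    def to_backend(self, target: str, convert: Callable[[Any], Any]) -> 'MiniQuantity':
        if self.backend == target:
            return self  # already on the target backend: no-op, same object
        # Otherwise build a new object; the original stays untouched.
        return replace(self, mantissa=convert(self.mantissa), backend=target)


def to_array(values):
    return array.array('d', values)


q_list = MiniQuantity([1.0, 2.0], 'm', 'list')
q_arr = q_list.to_backend('array', to_array)
print(q_arr.backend, q_arr is q_list)                # array False
print(q_list.backend, q_list.mantissa)               # list [1.0, 2.0]
print(q_arr.to_backend('array', to_array) is q_arr)  # True
```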
Installation
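| Command | Provides |
|---|---|
| pip install brainunit | core |
| pip install "brainunit[cpu]" | core + CPU JAX build |
| pip install "brainunit[cuda12]" | core + CUDA 12 JAX build |
| pip install "brainunit[cuda13]" | core + CUDA 13 JAX build |
| pip install "brainunit[tpu]" | core + TPU JAX build |
| pip install "brainunit[cupy]" | adds the cupy backend |
| pip install "brainunit[torch]" | adds the torch backend |
| pip install "brainunit[dask]" | adds the dask backend |
| pip install "brainunit[ndonnx]" | adds the ndonnx backend |
| … | shorthand for … |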
JAX is a required dependency — every install includes the JAX backend. The
[cpu] / [cuda12] / [cuda13] / [tpu] extras pin the JAX accelerator
build; pick at most one. The [cupy] / [torch] / [dask] / [ndonnx]
extras are independent and can be combined freely.
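Accelerator extras are mutually exclusive, while backend extras combine freely. A single requirement string should therefore be able to pair one of the former with any of the latter using standard pip extras syntax, e.g. brainunit[cuda12,torch,dask]. The sketch below is a hypothetical helper that builds such a string and enforces the at-most-one rule. The function name is illustrative, and the extras names are the ones listed above.

```python
ACCELERATOR_EXTRAS = ('cpu', 'cuda12', 'cuda13', 'tpu')  # pick at most one
BACKEND_EXTRAS = ('cupy', 'torch', 'dask', 'ndonnx')      # combine freely


def brainunit_requirement(accelerator=None, backends=()):
    """Hypothetical helper: build a pip requirement string for brainunit."""
    extras = []
    if accelerator is not None:
        if accelerator not in ACCELERATOR_EXTRAS:
            raise ValueError(f'unknown accelerator extra: {accelerator!r}')
        extras.append(accelerator)  # a single value, so never two accelerators
    for name in backends:
        if name not in BACKEND_EXTRAS:
            raise ValueError(f'unknown backend extra: {name!r}')
        if name not in extras:
            extras.append(name)
    return f'brainunit[{",".join(extras)}]' if extras else 'brainunit'


print(brainunit_requirement())                             # brainunit
print(brainunit_requirement('cuda12', ['torch', 'dask']))  # brainunit[cuda12,torch,dask]
```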