Source code for saiunit.autograd._vector_grad

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from functools import wraps
from typing import Callable, Sequence

import jax
import numpy as np
from jax import numpy as jnp

from saiunit._base_getters import get_mantissa, get_unit, maybe_decimal
from saiunit._base_quantity import Quantity
from saiunit._misc import maybe_custom_array_tree
from ._misc import _check_callable, _argnums_partial, _ensure_index

__all__ = [
    'vector_grad',
]


def vector_grad(
    func: Callable,
    argnums: int | Sequence[int] = 0,
    return_value: bool = False,
    has_aux: bool = False,
    unit_aware: bool = True,
):
    """
    Build a unit-aware gradient function for a vector-valued ``func``.

    The returned function evaluates ``func`` through :func:`jax.vjp` and pulls
    back a cotangent made entirely of ones. Mathematically this yields the
    column sums of the Jacobian (``ones @ J``). That quantity equals the
    element-wise derivative (the Jacobian diagonal) only when ``func`` acts
    element-wise on its input.

    Parameters
    ----------
    func : callable
        Function producing a vector output; its arguments may carry physical
        units.
    argnums : int or sequence of int, optional
        Index (or indices) of the positional argument(s) to differentiate
        with respect to. Default is ``0``.
    return_value : bool, optional
        When ``True``, the output of ``func`` is returned alongside the
        gradient. Default is ``False``.
    has_aux : bool, optional
        When ``True``, ``func`` must return ``(output, aux)``; only ``output``
        is differentiated and ``aux`` is passed through. Default is ``False``.
    unit_aware : bool, optional
        When ``True``, each gradient leaf is wrapped as a quantity whose unit
        is the output unit divided by the unit of the matching argument leaf.
        When ``False``, any unit carried by the gradient is stripped.
        Default is ``True``.

    Returns
    -------
    grad_fun : callable
        A function accepting the same arguments as ``func``. Its return value
        depends on the flags:

        - neither flag set: ``gradient``
        - ``return_value=True``: ``(gradient, value)``
        - ``has_aux=True``: ``(gradient, aux)``
        - both set: ``(gradient, value, aux)``

    Raises
    ------
    TypeError
        Raised by ``grad_fun`` when an output leaf has a dtype that is not a
        sub-dtype of ``numpy.inexact`` (for example integer or boolean).
    ValueError
        Raised by ``grad_fun`` when ``unit_aware=True`` and the output does
        not consist of exactly one leaf (a ``Quantity`` counts as one leaf).

    See Also
    --------
    grad : Gradient for scalar-valued functions.
    jacrev : Full Jacobian via reverse-mode AD.

    Examples
    --------
    Element-wise gradient of a squared function with units:

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.autograd as suauto
        >>> def f(x):
        ...     return x ** 2
        >>> vg_fn = suauto.vector_grad(f)
        >>> vg_fn(jnp.array([3.0, 4.0]) * u.ms)
        [6.0, 8.0] * ms

    Returning both the gradient and the function value:

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.autograd as suauto
        >>> def f(x):
        ...     return x ** 2
        >>> vg_fn = suauto.vector_grad(f, return_value=True)
        >>> grad, value = vg_fn(jnp.array([3.0, 4.0]) * u.ms)
        >>> grad
        [6.0, 8.0] * ms
        >>> value
        [9.0, 16.0] * ms ** 2
    """
    _check_callable(func)
    argnums = _ensure_index(argnums)

    def _is_quantity(node):
        # Stop pytree traversal at Quantity so it is handled as one leaf.
        return isinstance(node, Quantity)

    @wraps(func)
    def grad_fun(*args, **kwargs):
        args, kwargs = maybe_custom_array_tree((args, kwargs))
        argnums_, f_partial, dyn_args = _argnums_partial(func, argnums, args, kwargs)

        if has_aux:
            y, vjp_fn, aux = jax.vjp(f_partial, *dyn_args, has_aux=True)
        else:
            y, vjp_fn = jax.vjp(f_partial, *dyn_args)

        # Two views of the output: ``raw_leaves`` descends into Quantity and
        # is used to shape the cotangent, while ``unit_leaves`` keeps every
        # Quantity whole so the unit and the output count come from the
        # quantities themselves.
        raw_leaves, treedef = jax.tree.flatten(y)
        unit_leaves, _ = jax.tree.flatten(y, is_leaf=_is_quantity)

        # A non-inexact (integer, boolean, ...) output only produces
        # meaningless zero gradients, so refuse it up front.
        for out in unit_leaves:
            dtype = getattr(get_mantissa(out), 'dtype', None)
            if dtype is None or jax.dtypes.issubdtype(dtype, np.inexact):
                continue
            raise TypeError(
                f'vector_grad requires real- or complex-valued (inexact) outputs, '
                f'but got {np.dtype(dtype).name}.'
            )

        if unit_aware and len(unit_leaves) != 1:
            raise ValueError(
                f'vector_grad with unit_aware=True requires the function to return a single '
                f'array, but got {len(unit_leaves)} outputs.'
            )

        # All-ones cotangent with the same structure, shapes and dtypes as y.
        ones = [jnp.ones(leaf.shape, dtype=leaf.dtype) for leaf in raw_leaves]
        cotangent = jax.tree.unflatten(treedef, ones)
        grads = vjp_fn(cotangent)
        if isinstance(argnums_, int):
            # A single argnum: unwrap the one-element tuple returned by the VJP.
            grads = grads[0]

        if unit_aware:
            wrt_args = jax.tree.map(lambda i: args[i], argnums_)
            out_unit = get_unit(unit_leaves[0])

            def _attach_unit(arg, grad):
                # d(output)/d(arg) carries unit(output) / unit(arg).
                return maybe_decimal(
                    Quantity(get_mantissa(grad), unit=out_unit / get_unit(arg))
                )

            grads = jax.tree.map(_attach_unit, wrt_args, grads, is_leaf=_is_quantity)
        else:
            # Unit awareness is off: drop whatever unit the VJP carried back
            # from unit-ful inputs so the result does not advertise a wrong one.
            grads = jax.tree.map(get_mantissa, grads, is_leaf=_is_quantity)

        # Assemble (gradient[, value][, aux]); a lone gradient is returned bare.
        outputs = [grads]
        if return_value:
            outputs.append(y)
        if has_aux:
            outputs.append(aux)
        return grads if len(outputs) == 1 else tuple(outputs)

    return grad_fun