einx.jax.adapt_with_vmap

einx.jax.adapt_with_vmap(op, signature=None)

Adapts an operation to einx notation using jax.vmap.
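As background, jax.vmap turns a function written for unbatched tensors into one that maps over an extra axis. The sketch below uses plain JAX only (no einx) and a hypothetical elementary operation `dot`; `in_axes=(0, None)` vectorizes the first input over its leading axis while the second input is shared across all calls.

```python
import jax
import jax.numpy as jnp

# Hypothetical elementary operation on unbatched tensors:
# x has shape (c,), y has shape (c,), result is a scalar.
def dot(x, y):
    return jnp.sum(x * y)

# Map over axis 0 of x; y has no vectorized axis (in_axes=None).
batched_dot = jax.vmap(dot, in_axes=(0, None))

x = jnp.arange(6.0).reshape(2, 3)  # shape (2, 3): one vectorized axis of size 2
y = jnp.ones(3)                    # shape (3,): shared by every call
result = batched_dot(x, y)         # shape (2,)
```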

The operation is expected to have one of the following signatures:

def op(*tensors: Tensor) -> Tensor:
    ...

def op(*tensors: Tensor) -> Tuple[Tensor, ...]:
    ...
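The second signature returns several tensors at once. As a minimal plain-JAX sketch (the function `min_max` and the `Tensor = jax.Array` alias are assumptions for illustration, not part of einx), jax.vmap vectorizes every element of a returned tuple:

```python
from typing import Tuple

import jax
import jax.numpy as jnp

Tensor = jax.Array  # assumption: stand-in for the generic Tensor type above

def min_max(x: Tensor) -> Tuple[Tensor, Tensor]:
    # Hypothetical elementary op on one vector of shape (c,); returns two scalars.
    return jnp.min(x), jnp.max(x)

# vmap maps over axis 0 of x and stacks each tuple element separately.
batched_min_max = jax.vmap(min_max)

x = jnp.array([[3.0, 1.0, 2.0], [0.0, 5.0, 4.0]])  # shape (2, 3)
lo, hi = batched_min_max(x)                         # each of shape (2,)
```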

The number and shapes of the input and output tensors of op must match the elementary operation specified in the einx expression, i.e. each tensor contains only the non-vectorized axes (the axes marked with brackets). For example:

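import einx
import jax.numpy as jnp

@einx.jax.adapt_with_vmap
def einop(x, y):
    # shape of x is (b, c)
    # shape of y is (c,)
    return (x * y).T  # illustrative body; shape of result must be (c, b)

x = jnp.ones((2, 3, 4, 5))  # axes a b c x
y = jnp.ones((6, 4))        # axes d c
z = einop("a [b c] x, d [c] -> a d x [c b]", x, y)  # shape (2, 6, 5, 4, 3)
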
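To make the vectorization concrete, the expression above can be emulated by hand with nested jax.vmap calls. This is a sketch of what the adapted operation computes, not einx's actual internals; the body `(x * y).T` is a hypothetical choice that satisfies the required shapes, and `xs` is used for the size of axis `x` to avoid clashing with the argument name.

```python
import jax
import jax.numpy as jnp

def op(x, y):
    # Elementary op: x has shape (b, c), y has shape (c,), result has shape (c, b).
    return (x * y).T

# Innermost: map over axis "x" of the first input (axis 2 of a (b, c, x) slice).
f = jax.vmap(op, in_axes=(2, None))
# Middle: map over axis "d" of the second input; the first input is shared.
f = jax.vmap(f, in_axes=(None, 0))
# Outermost: map over axis "a" of the first input; the second input is shared.
f = jax.vmap(f, in_axes=(0, None))

a, b, c, d, xs = 2, 3, 4, 5, 6
X = jnp.arange(a * b * c * xs, dtype=jnp.float32).reshape(a, b, c, xs)  # a b c x
Y = jnp.arange(d * c, dtype=jnp.float32).reshape(d, c)                  # d c
z = f(X, Y)  # shape (a, d, x, c, b)
```

Each jax.vmap strips one vectorized axis, and the default `out_axes=0` places the new axis in front. That is why the wrappers are applied from the innermost output axis (`x`) outward to `a`, yielding the output order `a d x [c b]`.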
Parameters:

op – The operation that will be adapted to einx notation.

Returns:

A new operation that follows einx notation and internally invokes the original operation.