einx.mlx.adapt_with_vmap
- einx.mlx.adapt_with_vmap(op, signature=None)
Adapts an operation to einx notation using mlx.core.vmap. The operation is expected to have one of the following signatures:

    def op(*tensors: Tensor) -> Tensor: ...
    def op(*tensors: Tensor) -> Tuple[Tensor, ...]: ...

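From the caller's point of view, the two signatures mainly differ in how outputs are collected after vectorization. As a rough illustration, and not einx's implementation, the sketch below assumes NumPy and uses a hypothetical `loop_vmap` helper that loops over the leading axis of every input in place of `mlx.core.vmap`: a single returned tensor is stacked directly, while a returned tuple is stacked position by position.

```python
import numpy as np

def loop_vmap(op):
    """Hypothetical loop-based stand-in for vmap over axis 0 of every input.

    Handles both signatures: op returning one array, or a tuple of arrays.
    """
    def vectorized(*tensors):
        n = tensors[0].shape[0]
        # Call the elementary operation once per slice along axis 0.
        results = [op(*(t[i] for t in tensors)) for i in range(n)]
        if isinstance(results[0], tuple):
            # Tuple signature: stack each output position separately.
            return tuple(np.stack(outs) for outs in zip(*results))
        # Single-tensor signature: stack the per-slice results.
        return np.stack(results)
    return vectorized

# Signature 1: tensors in, one tensor out.
def normalize(x):
    return x / x.sum()

# Signature 2: tensors in, a tuple of tensors out.
def min_max(x):
    return x.min(), x.max()

x = np.array([[1.0, 3.0], [2.0, 2.0]])
normalized = loop_vmap(normalize)(x)  # one stacked array, shape (2, 2)
lo, hi = loop_vmap(min_max)(x)        # two stacked arrays, each shape (2,)
```

With the real function, the vectorized axes are determined by the einx expression rather than fixed to the leading axis.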
The number and shapes of the input and output tensors must match the elementary operation specified in the einx expression (i.e. each tensor contains only the non-vectorized axes, which are marked with brackets in the expression). For example:

    @einx.mlx.adapt_with_vmap
    def einop(x, y):
        # shape of x is (b, c)
        # shape of y is (c)
        ...
        return ...  # shape of result must be (c, b)

    z = einop("a [b c] x, d [c] -> a d x [c b]", x, y)

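To make the semantics of this example concrete without depending on MLX, the following sketch assumes NumPy, fills in a hypothetical elementary operation `op` (broadcasted multiply, then transpose), and evaluates the expression with explicit loops over the vectorized axes `a`, `d` and `x`. It is only a reference for what the adapted operation computes; einx itself vectorizes with `mlx.core.vmap` rather than Python loops.

```python
import numpy as np

def op(x, y):
    # Hypothetical elementary operation: x has shape (b, c), y has shape (c,).
    # Broadcast-multiply to (b, c), then transpose to the required (c, b).
    return (x * y).T

def reference(x, y):
    # Hypothetical loop-based reference for "a [b c] x, d [c] -> a d x [c b]".
    # x has shape (a, b, c, x); y has shape (d, c).
    A, B, C, X = x.shape
    D = y.shape[0]
    z = np.empty((A, D, X, C, B), dtype=np.result_type(x, y))
    # Call op once for every combination of the vectorized axes a, d and x.
    for ia in range(A):
        for id_ in range(D):
            for ix in range(X):
                z[ia, id_, ix] = op(x[ia, :, :, ix], y[id_])
    return z

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 5))  # axes a b c x
y = rng.standard_normal((6, 4))        # axes d c
z = reference(x, y)                    # axes a d x c b

# Cross-check: c appears in the output, so einsum multiplies over it without summing.
z_vec = np.einsum("abcx,dc->adxcb", x, y)
```

Because this particular `op` is a plain broadcasted product, `np.einsum` reproduces it in a single call; an arbitrary `op` has no such shortcut, which is where vmap-based adaptation is useful.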
- Parameters:
op – The operation that will be adapted to einx notation.
- Returns:
A new operation that follows einx notation and internally invokes the original operation.