einx.solve_axes

einx.solve_axes(description, *tensors, **parameters)

Solve for the lengths of all axes in the description, using the tensor shapes and keyword parameters as constraints.
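To illustrate what "solving" means here, the following is a minimal sketch, not einx's implementation. The helper `solve_axes_sketch` is hypothetical: it takes shapes (tuples) instead of tensors, supports only plain space-separated axis names (no ellipses or compositions), and treats keyword arguments as fixed lengths. Each known dimension either assigns a length to an axis or must agree with the length already assigned; an axis that no constraint determines is reported as an error.

```python
from typing import Dict, Optional, Tuple


def solve_axes_sketch(
    description: str, *shapes: Optional[Tuple[int, ...]], **params: int
) -> Dict[str, int]:
    """Hypothetical simplified solver: plain axis names only, shapes instead of tensors."""
    exprs = [expr.split() for expr in description.split(",")]
    if len(exprs) != len(shapes):
        raise ValueError("expected one shape (or None) per expression")
    solved: Dict[str, int] = dict(params)  # keyword arguments act as fixed constraints
    for names, shape in zip(exprs, shapes):
        if shape is None:  # unknown shape: adds no constraint
            continue
        if len(names) != len(shape):
            raise ValueError(f"{' '.join(names)!r} has {len(names)} axes but shape {shape}")
        for name, length in zip(names, shape):
            # The first occurrence assigns the length; later ones must agree with it.
            if solved.setdefault(name, length) != length:
                raise ValueError(f"axis {name!r}: {solved[name]} != {length}")
    missing = {name for names in exprs for name in names} - set(solved)
    if missing:
        raise ValueError(f"cannot determine lengths of {sorted(missing)}")
    return solved


# Mirrors the first doctest, with the NumPy array replaced by its shape (3, 4).
axes = solve_axes_sketch("a b, c b a", (3, 4), None, c=3)
print(axes == {"a": 3, "b": 4, "c": 3})  # True
```

Passing shapes rather than arrays keeps the sketch dependency-free; the propagation logic would be the same if each shape were read from a tensor's `.shape`.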

Parameters:
  • description (str) – Comma-separated list of tensor expressions in einx notation.

  • *tensors (Tensor) – Tensors matching the description string, one per expression. Pass None for a tensor whose shape is unknown.

  • **parameters (numpy.typing.ArrayLike) – Additional keyword parameters that specify axis lengths, e.g. a=4.

Return type:

Mapping[str, numpy.typing.ArrayLike]

Returns:

A mapping from each axis name to its length. If an axis is used with an ellipsis, its value holds one integer length per dimension covered by the ellipsis (returned as an integer array in the example below).
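The ellipsis case can be illustrated with a small, self-contained sketch. It is not einx's implementation: `solve_with_ellipsis` is a hypothetical helper that accepts shapes instead of tensors, supports only plain axis names plus at most one `name...` per expression (no bare `...`), and returns a Python list for an ellipsis axis where the einx example shows an integer array. The ellipsis absorbs whatever dimensions are left after the plain axes on either side of it are matched.

```python
from typing import Dict, List, Optional, Tuple, Union

Length = Union[int, List[int]]  # an ellipsis axis maps to a list of lengths


def solve_with_ellipsis(
    description: str, *shapes: Optional[Tuple[int, ...]], **params: int
) -> Dict[str, Length]:
    """Hypothetical sketch: at most one `name...` per expression, no bare `...`."""
    solved: Dict[str, Length] = dict(params)
    for expr, shape in zip(description.split(","), shapes):
        if shape is None:  # unknown shape: adds no constraint
            continue
        names = expr.split()
        ell = [i for i, name in enumerate(names) if name.endswith("...")]
        if len(ell) > 1:
            raise ValueError("this sketch allows one ellipsis per expression")
        if ell:
            i = ell[0]
            n_ell = len(shape) - (len(names) - 1)  # dimensions absorbed by the ellipsis
            if n_ell < 0:
                raise ValueError(f"{expr.strip()!r} needs more dimensions than {shape}")
            pairs: List[Tuple[str, Length]] = list(zip(names[:i], shape[:i]))
            pairs.append((names[i][:-3], list(shape[i : i + n_ell])))
            pairs += zip(names[i + 1 :], shape[i + n_ell :])
        else:
            if len(names) != len(shape):
                raise ValueError(f"{expr.strip()!r} does not match shape {shape}")
            pairs = list(zip(names, shape))
        for name, length in pairs:
            # Lists compare element-wise, so a repeated ellipsis axis must match exactly.
            if solved.setdefault(name, length) != length:
                raise ValueError(f"axis {name!r}: {solved[name]} != {length}")
    return solved


# Mirrors the second doctest: `a...` absorbs both dimensions of a (3, 4) tensor.
print(solve_with_ellipsis("a..., c a...", (3, 4), None, c=3))  # {'c': 3, 'a': [3, 4]}
```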

Example

>>> import numpy as np
>>> import einx
>>> x = np.random.rand(3, 4)
>>> einx.solve_axes("a b, c b a", x, None, c=3)
{'a': 3, 'b': 4, 'c': 3}
>>> einx.solve_axes("a..., c a...", x, None, c=3)
{'a': array([3, 4], dtype=int32), 'c': 3}