einx.solve_shapes

einx.solve_shapes(description, *tensors, **parameters)

Solve for the shapes of the tensors described by the einx expressions, using the known tensor shapes and the dimension sizes passed as parameters as constraints.

Parameters:
  • description (str) – Comma-separated list of tensor expressions in einx notation.

  • *tensors (Tensor) – Tensors matching the description string, one per expression. Pass None for a tensor whose shape is unknown.

  • **parameters (numpy.typing.ArrayLike) – Additional parameters that specify dimension sizes, e.g. a=4.

Return type:

tuple[tuple[int, ...], ...]

Returns:

A tuple of shapes corresponding to the input tensors.

Example

>>> import numpy as np
>>> import einx
>>> x = np.random.rand(3, 4)
>>> einx.solve_shapes("a b, c b a", x, None, c=3)
((3, 4), (3, 4, 3))
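To illustrate what the solver computes, here is a minimal sketch of the underlying idea. It assumes the description contains only plain, space-separated axis names (no ellipses, composed axes such as (a b), or brackets). The function solve_flat_shapes is hypothetical and is not part of einx: it collects axis sizes from the keyword parameters and from every tensor whose shape is known, checks them for conflicts, and then builds one shape per expression.

```python
from typing import Optional, Sequence


def solve_flat_shapes(
    description: str,
    *shapes: Optional[Sequence[int]],
    **sizes: int,
) -> tuple[tuple[int, ...], ...]:
    """Hypothetical, simplified solver: plain named axes only (not einx's API)."""
    expressions = [expr.split() for expr in description.split(",")]
    if len(expressions) != len(shapes):
        raise ValueError("expected one shape (or None) per expression")

    # Start from the explicitly given axis sizes, e.g. c=3.
    known: dict[str, int] = dict(sizes)

    # Add constraints from every tensor whose shape is known.
    for axes, shape in zip(expressions, shapes):
        if shape is None:
            continue
        if len(axes) != len(shape):
            raise ValueError(f"rank mismatch: {axes} vs {tuple(shape)}")
        for axis, size in zip(axes, shape):
            # Record the size, or fail if it contradicts an earlier one.
            if known.setdefault(axis, size) != size:
                raise ValueError(f"conflicting sizes for axis {axis!r}")

    # Every axis must now be determined; build one shape per expression.
    result = []
    for axes in expressions:
        missing = [axis for axis in axes if axis not in known]
        if missing:
            raise ValueError(f"cannot determine sizes of {missing}")
        result.append(tuple(known[axis] for axis in axes))
    return tuple(result)


# Mirrors the einx example above, passing shapes instead of arrays.
print(solve_flat_shapes("a b, c b a", (3, 4), None, c=3))  # ((3, 4), (3, 4, 3))
```

The real einx notation is considerably richer than this sketch, so the sketch only shows how known shapes and explicit parameters jointly determine the unknown shapes.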