from einx._src.adapter.einx_from_namedtensor import _parse_op
from einx._src.adapter.einx_from_namedtensor import Invocation
from einx._src.adapter.einx_from_namedtensor import solve as _solve2
from einx._src.frontend.errors import SyntaxError
from einx._src.namedtensor import ExpressionIndicator
import einx._src.namedtensor.stage3 as stage3
from collections import defaultdict
import numpy as np
import numpy.typing as npt
from collections.abc import Mapping
from .api import _is_scalar
from .types import Tensor
import einx._src.tracer as tracer
import types
import warnings
def _exprs_to_axes(exprs):
    """Collect the values of all ``stage3.Axis`` nodes, grouped by axis base name.

    Every node of every root expression is visited. For each ``stage3.Axis``
    node, the name is split on ``"."``: the first token is the base name and
    any remaining tokens are parsed as integer coordinates (presumably these
    dotted names come from ellipsis expansion -- confirm against stage3).

    Args:
        exprs: Iterable of root expressions exposing a ``nodes()`` method.

    Returns:
        A dict mapping each base name to either a plain ``int`` (when the axis
        has no coordinates) or an ``int32`` numpy array indexed by the
        coordinates. If the same name/coordinate occurs more than once, the
        value seen last wins.
    """
    entries_by_name = defaultdict(list)
    for root in exprs:
        for node in root.nodes():
            if not isinstance(node, stage3.Axis):
                continue
            base, *indices = node.name.split(".")
            coord = tuple(int(index) for index in indices)
            entries_by_name[base].append((coord, node.value))

    axes = {}
    for base, entries in entries_by_name.items():
        # Coordinate-wise maximum plus one is the smallest extent that holds
        # every recorded coordinate; for coordinate-free axes it is empty,
        # which yields a 0-d array below.
        extent = np.amax([coord for coord, _ in entries], axis=0) + 1
        lengths = np.zeros(extent, dtype="int32")
        for coord, length in entries:
            lengths[coord] = length
        # Unwrap 0-d arrays so plain axes are reported as Python ints.
        axes[base] = int(lengths) if lengths.shape == () else lengths
    return axes
def _solve(description, tensor_shapes, parameters, reraise, cse):
    """Parse ``description`` and solve it against the given tensor shapes.

    Args:
        description: Expression string; it must not contain ``"->"``.
        tensor_shapes: One shape per tensor, or ``None`` where the shape is
            unknown.
        parameters: Forwarded unchanged to the solver.
        reraise: If true, exceptions raised while parsing or solving
            propagate; otherwise they are swallowed and ``None`` is returned.
        cse: Forwarded to the solver as its ``cse`` flag.

    Returns:
        The solved input expressions, or ``None`` if parsing/solving failed
        and ``reraise`` is false.

    Raises:
        SyntaxError: If ``description`` contains ``"->"``. This check sits
            outside the ``try`` block, so it is raised regardless of
            ``reraise``.
    """
    # Build one tensor signature per shape; unknown shapes get a
    # ConvertibleTensor placeholder without a shape.
    signatures = []
    for shape in tensor_shapes:
        if shape is None:
            signature = tracer.signature.classical.ConvertibleTensor(None, shape=None, concrete=types.SimpleNamespace(type=None))
        else:
            signature = tracer.signature.classical.Tensor(None, shape=shape)
        signatures.append(signature)

    invocation = Invocation(description, name="operation", tensors=signatures, kwargs={})

    if "->" in description:
        indicator = ExpressionIndicator(description)
        arrow_pos = indicator.get_pos_for_literal("->")
        raise SyntaxError(description, pos=arrow_pos, message="The expression must not contain a '->' operator.\n%EXPR%")

    try:
        # The parser is handed the description with a trailing "->" appended.
        parsed_in, parsed_out = _parse_op(f"{description} ->", el_op=None, invocation=invocation, allow_concat=True)
        solved_in, _ = _solve2(parsed_in, parsed_out, tensor_shapes, invocation, parameters, cse_concat=True, cse=cse)
    except Exception:
        if not reraise:
            return None
        raise
    return solved_in
def _get_shape(tensor):
    """Return the shape of ``tensor`` as a tuple of ints, or ``None`` if unknown.

    Args:
        tensor: ``None``, an object with a ``shape`` attribute, a scalar (as
            judged by ``_is_scalar``), or a callable.

    Returns:
        ``None`` for ``None`` and for callables, ``()`` for scalars, and
        otherwise ``tensor.shape`` with every entry converted to ``int``.

    Raises:
        ValueError: If ``tensor`` is none of the accepted kinds.
    """
    if tensor is None:
        return None
    try:
        return tuple(int(x) for x in tensor.shape)
    except Exception:
        # No usable ``shape`` attribute: fall through to the scalar/callable
        # checks. Only ``Exception`` is caught so that KeyboardInterrupt and
        # SystemExit are not silently swallowed.
        pass
    if _is_scalar(tensor):
        return ()
    if callable(tensor):
        return None
    raise ValueError(f"Found {type(tensor)} which is not a valid tensor argument.")
[docs]
def solve_shapes(description: str, *tensors: Tensor, **parameters: npt.ArrayLike) -> tuple[tuple[int, ...], ...]:
    """Solve for the shapes of the einx expressions under the given constraints.

    Args:
        description: Comma-separated list of tensor expressions in einx notation.
        *tensors: Tensors matching the description string. Accepts ``None`` for unknown
            shapes.
        **parameters: Additional parameters that specify dimension sizes, e.g. ``a=4``.

    Returns:
        A tuple of shapes corresponding to the input tensors.

    Raises:
        Exception: Any error raised while parsing or solving the description
            is propagated to the caller.

    Example:
        >>> x = np.random.rand(3, 4)
        >>> einx.solve_shapes("a b, c b a", x, None, c=3)
        ((3, 4), (3, 4, 3))
    """
    shapes = [_get_shape(tensor) for tensor in tensors]
    exprs = _solve(description, shapes, parameters, reraise=True, cse=True)
    return tuple(expr.shape for expr in exprs)
[docs]
def solve_axes(description: str, *tensors: Tensor, **parameters: npt.ArrayLike) -> Mapping[str, npt.ArrayLike]:
    """Solve for the length of every axis in an expression under the given constraints.

    Args:
        description: Comma-separated list of tensor expressions in einx notation.
        *tensors: Tensors matching the description string. Pass ``None`` where the
            shape is unknown.
        **parameters: Additional parameters that specify dimension sizes, e.g. ``a=4``.

    Returns:
        A mapping from axis name to length. A plain axis maps to an ``int``; an
        axis used with an ellipsis maps to an integer array holding one length
        per expanded axis.

    Example:
        >>> x = np.random.rand(3, 4)
        >>> einx.solve_axes("a b, c b a", x, None, c=3)
        {'a': 3, 'b': 4, 'c': 3}
        >>> einx.solve_axes("a..., c a...", x, None, c=3)
        {'a': array([3, 4], dtype=int32), 'c': 3}
    """
    shapes = list(map(_get_shape, tensors))
    solved = _solve(description, shapes, parameters, reraise=True, cse=False)
    return _exprs_to_axes(solved)
def solve(description: str, *tensors: Tensor, **parameters: npt.ArrayLike) -> Mapping[str, npt.ArrayLike]:
    """Alias for :func:`einx.solve_axes`; all arguments are forwarded unchanged."""
    axes = solve_axes(description, *tensors, **parameters)
    return axes
[docs]
def matches(description: str, *tensors: Tensor, **parameters: npt.ArrayLike) -> bool:
    """Return whether the given tensors match the einx expression description
    under the given constraints.

    Args:
        description: Comma-separated list of tensor expressions in einx notation.
        *tensors: Tensors matching the description string. Accepts ``None`` for unknown
            shapes.
        **parameters: Additional parameters that specify dimension sizes, e.g. ``a=4``.

    Returns:
        True if the tensors and constraints match the description, False otherwise.

    Example:
        >>> x = np.random.rand(3, 4)
        >>> einx.matches("a b", x)
        True
        >>> einx.matches("a b c", x)
        False
    """
    try:
        solve_shapes(description, *tensors, **parameters)
    except Exception:
        # Any failure to parse or solve counts as "does not match". Only
        # ``Exception`` is caught so that KeyboardInterrupt and SystemExit
        # still propagate instead of being reported as a mismatch.
        return False
    return True
def check(description: str, *tensors: Tensor, **parameters: npt.ArrayLike) -> None:
    """Deprecated: validate that the tensors match ``description``.

    Emits a ``DeprecationWarning`` attributed to the caller, then solves the
    description against the tensor shapes. Any error raised while parsing or
    solving is propagated; nothing is returned on success.

    Args:
        description: Comma-separated list of tensor expressions in einx notation.
        *tensors: Tensors matching the description string. Accepts ``None`` for unknown
            shapes.
        **parameters: Additional parameters that specify dimension sizes, e.g. ``a=4``.
    """
    warnings.warn(
        "einx.check is deprecated and will be removed in a future release. Please call einx.id instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    shapes = [_get_shape(tensor) for tensor in tensors]
    _solve(description, shapes, parameters, reraise=True, cse=True)