Source code for einx._src.frontend.util

from einx._src.adapter.einx_from_namedtensor import _parse_op
from einx._src.adapter.einx_from_namedtensor import Invocation
from einx._src.adapter.einx_from_namedtensor import solve as _solve2
from einx._src.frontend.errors import SyntaxError
from einx._src.namedtensor import ExpressionIndicator
import einx._src.namedtensor.stage3 as stage3
from collections import defaultdict
import numpy as np
import numpy.typing as npt
from collections.abc import Mapping
from .api import _is_scalar
from .types import Tensor
import einx._src.tracer as tracer
import types
import warnings


def _exprs_to_axes(exprs):
    """Collect the solved length of every axis that appears in ``exprs``.

    Every ``stage3.Axis`` node reachable via ``root.nodes()`` is inspected. Its
    name is split on ``"."``: the first token is the base axis name, and any
    remaining tokens are parsed as integer indices. Presumably these are the
    positions of axes expanded from an ellipsis; confirm against stage3.

    Args:
        exprs: Iterable of solved root expressions exposing ``nodes()``.

    Returns:
        dict mapping each base axis name to either a plain ``int`` (when its
        names carry no indices) or an ``int32`` NumPy array holding the
        ``value`` of every indexed axis at its index. Positions never seen
        stay 0.
    """
    # base name -> list of (integer index tuple, axis value)
    entries = defaultdict(list)
    for root in exprs:
        axes = (node for node in root.nodes() if isinstance(node, stage3.Axis))
        for axis in axes:
            base, *suffix = axis.name.split(".")
            index = tuple(map(int, suffix))
            entries[base].append((index, axis.value))

    def _assemble(pairs):
        # Size the table so the largest index along each dimension fits.
        indices = [index for index, _ in pairs]
        extent = np.amax(indices, axis=0) + 1
        table = np.zeros(extent, dtype="int32")
        for index, length in pairs:
            table[index] = length
        # Un-indexed axes produce a 0-d table; report them as a plain int.
        return int(table) if table.shape == () else table

    return {base: _assemble(pairs) for base, pairs in entries.items()}


def _solve(description, tensor_shapes, parameters, reraise, cse):
    """Parse ``description`` and solve it against the given tensor shapes.

    Args:
        description: Comma-separated einx expressions. Must not contain
            ``"->"``; this function appends ``" ->"`` itself to build an
            operation with no outputs.
        tensor_shapes: One entry per tensor: a shape, or ``None`` if unknown.
        parameters: Extra axis constraints forwarded to the solver.
        reraise: If true, errors from parsing/solving propagate. Otherwise
            ``None`` is returned instead.
        cse: Forwarded unchanged to the solver.

    Returns:
        The solved input expressions, or ``None`` on failure when ``reraise``
        is false.

    Raises:
        SyntaxError: (einx's own class) if ``description`` contains ``"->"``.
            This check happens outside the guarded region, so it is raised
            regardless of ``reraise``.
    """

    def _signature_tensor(shape):
        if shape is None:
            # Unknown shape: describe it as a convertible tensor with no
            # concrete type attached.
            return tracer.signature.classical.ConvertibleTensor(
                None, shape=None, concrete=types.SimpleNamespace(type=None)
            )
        return tracer.signature.classical.Tensor(None, shape=shape)

    invocation = Invocation(
        description,
        name="operation",
        tensors=[_signature_tensor(shape) for shape in tensor_shapes],
        kwargs={},
    )

    if "->" in description:
        indicator = ExpressionIndicator(description)
        raise SyntaxError(
            description,
            pos=indicator.get_pos_for_literal("->"),
            message="The expression must not contain a '->' operator.\n%EXPR%",
        )

    try:
        parsed_in, parsed_out = _parse_op(
            f"{description} ->", el_op=None, invocation=invocation, allow_concat=True
        )
        solved_in, _ = _solve2(
            parsed_in, parsed_out, tensor_shapes, invocation, parameters, cse_concat=True, cse=cse
        )
    except Exception:
        if not reraise:
            return None
        raise
    return solved_in


def _get_shape(tensor):
    """Return the static shape of ``tensor`` for the solver, or ``None`` if unknown.

    Args:
        tensor: A tensor-like object, ``None``, a scalar, or a callable.

    Returns:
        - ``None`` if ``tensor`` is ``None``.
        - ``tuple(int(x) for x in tensor.shape)`` when that conversion succeeds.
        - ``()`` if ``_is_scalar(tensor)`` is true.
        - ``None`` if ``tensor`` is callable.

    Raises:
        ValueError: If ``tensor`` matches none of the cases above.
    """
    if tensor is None:
        return None
    try:
        return tuple(int(x) for x in tensor.shape)
    except Exception:
        # No usable static shape (e.g. no ``.shape`` attribute, or entries that
        # cannot be converted to int); fall back to the checks below. Catch
        # ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt/SystemExit are not silently swallowed.
        pass
    if _is_scalar(tensor):
        return ()
    if callable(tensor):
        return None
    raise ValueError(f"Found {type(tensor)} which is not a valid tensor argument.")


def solve_shapes(description: str, *tensors: Tensor, **parameters: npt.ArrayLike) -> tuple[tuple[int, ...], ...]:
    """Solve for the shapes of the einx expressions under the given constraints.

    Args:
        description: Comma-separated list of tensor expressions in einx notation.
        *tensors: Tensors matching the description string. Accepts ``None`` for unknown shapes.
        **parameters: Additional parameters that specify dimension sizes, e.g. ``a=4``.

    Returns:
        A tuple of shapes corresponding to the input tensors.

    Raises:
        ValueError: If one of ``tensors`` is not a valid tensor argument.
        SyntaxError: (einx's own class) If ``description`` contains ``"->"``.
        Any error raised while parsing or solving the expressions is propagated.

    Example:
        >>> x = np.random.rand(3, 4)
        >>> einx.solve_shapes("a b, c b a", x, None, c=3)
        ((3, 4), (3, 4, 3))
    """
    shapes = [_get_shape(tensor) for tensor in tensors]
    exprs = _solve(description, shapes, parameters, reraise=True, cse=True)
    return tuple(expr.shape for expr in exprs)
def solve_axes(description: str, *tensors: Tensor, **parameters: npt.ArrayLike) -> Mapping[str, npt.ArrayLike]:
    """Solve for the length of all axes in an expression under the given constraints.

    Args:
        description: Comma-separated list of tensor expressions in einx notation.
        *tensors: Tensors matching the description string. Accepts ``None`` for unknown shapes.
        **parameters: Additional parameters that specify dimension sizes, e.g. ``a=4``.

    Returns:
        A mapping from axis name to its length. Axes used with an ellipsis map
        to a NumPy ``int32`` array of lengths; all other axes map to a plain
        ``int``.

    Raises:
        ValueError: If one of ``tensors`` is not a valid tensor argument.
        SyntaxError: (einx's own class) If ``description`` contains ``"->"``.
        Any error raised while parsing or solving the expressions is propagated.

    Example:
        >>> x = np.random.rand(3, 4)
        >>> einx.solve_axes("a b, c b a", x, None, c=3)
        {'a': 3, 'b': 4, 'c': 3}
        >>> einx.solve_axes("a..., c a...", x, None, c=3)
        {'a': array([3, 4], dtype=int32), 'c': 3}
    """
    shapes = [_get_shape(tensor) for tensor in tensors]
    # NOTE(review): cse=False here vs cse=True in solve_shapes — presumably so
    # that every axis is reported individually rather than merged; confirm.
    exprs = _solve(description, shapes, parameters, reraise=True, cse=False)
    return _exprs_to_axes(exprs)
def solve(description: str, *tensors: Tensor, **parameters: npt.ArrayLike) -> Mapping[str, npt.ArrayLike]:
    """Alias for :func:`einx.solve_axes`; arguments and result are passed through unchanged."""
    axis_lengths = solve_axes(description, *tensors, **parameters)
    return axis_lengths
def matches(description: str, *tensors: Tensor, **parameters: npt.ArrayLike) -> bool:
    """Returns whether the given tensors match the einx expression description under the given constraints.

    Args:
        description: Comma-separated list of tensor expressions in einx notation.
        *tensors: Tensors matching the description string. Accepts ``None`` for unknown shapes.
        **parameters: Additional parameters that specify dimension sizes, e.g. ``a=4``.

    Returns:
        True if the tensors and constraints match the description, False otherwise.

    Example:
        >>> x = np.random.rand(3, 4)
        >>> einx.matches("a b", x)
        True
        >>> einx.matches("a b c", x)
        False
    """
    try:
        solve_shapes(description, *tensors, **parameters)
    except Exception:
        # Any failure to solve means "does not match". A bare ``except`` would
        # also turn KeyboardInterrupt/SystemExit into ``False``.
        return False
    return True
def check(description: str, *tensors: Tensor, **parameters: npt.ArrayLike) -> None:
    """Deprecated: validate that ``tensors`` satisfy ``description``.

    Emits a ``DeprecationWarning``, attributed to the caller via
    ``stacklevel=2``. Then solves the expression with ``reraise=True``, so any
    mismatch or invalid description raises instead of returning a value.
    """
    warnings.warn(
        "einx.check is deprecated and will be removed in a future release. Please call einx.id instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    shapes = [_get_shape(tensor) for tensor in tensors]
    _solve(description, shapes, parameters, reraise=True, cse=True)