Examples of compiled code
The following are examples of various einx operations, each shown with the Python code that einx compiles it to, using either
the default backend or an explicitly specified backend. The compiled code of any einx operation can be inspected by passing graph=True to it.
Axis permutation
The operation
x = np.zeros((10, 5, 2))
einx.id("a b c -> b c a", x)  # moves axis a to the end: output shape (5, 2, 10)
compiles to the following code:
With backend="numpy"
import numpy as np
def op(a):
    # "a b c -> b c a": a single transpose moves axis a to the last position.
    a = np.transpose(a, (1, 2, 0))
    return a
With backend="torch"
import torch
def op(a):
    # Convert the input to a torch tensor first; the example passes a NumPy array.
    a = torch.asarray(a, device=None)
    # "a b c -> b c a": move axis a to the last position.
    a = torch.permute(a, (1, 2, 0))
    return a
With backend="jax"
import jax.numpy as jnp
def op(a):
    # "a b c -> b c a": a single transpose moves axis a to the last position.
    a = jnp.transpose(a, (1, 2, 0))
    return a
With backend="arrayapi"
import array_api_compat
def op(a):
    # Look up the Array API namespace for the input array, then permute with it.
    b = array_api_compat.array_namespace(a)
    c = b.permute_dims(a, (1, 2, 0))
    return c
Axis flattening
The operation
x = np.zeros((10, 5))
einx.id("(a b) c -> a (b c)", x, b=2)  # a=5, b=2, c=5: output shape (5, 10)
compiles to the following code:
import numpy as np
def op(a):
    # "(a b) c -> a (b c)" with a=5, b=2, c=5: both sides list the axes in the
    # same row-major order, so a single reshape is enough.
    a = np.reshape(a, (5, 10))
    return a
No-op
The operation
x = np.zeros((10, 5))
einx.id("a b -> a b", x)  # identical input and output patterns
compiles to the following code:
def op(a):
    # Input and output patterns are identical, so the input is returned as-is.
    return a
Element-wise multiplication
The operation
x = jnp.zeros((2, (5 * 6)))  # second axis is the flattened (d e) with d=5, e=6
y = jnp.zeros((4, 3, 6))
einx.multiply("a (d e), c b e -> a b c d e", x, y)  # output shape (2, 3, 4, 5, 6)
compiles to the following code:
With backend="jax.numpylike"
import jax.numpy as jnp
def op(a, b):
    # Bring both operands into the output layout (a b c d e), inserting
    # size-1 axes where an operand lacks an axis, then broadcast-multiply.
    a = jnp.reshape(a, (2, 1, 1, 5, 6))  # a (d e) -> a 1 1 d e
    b = jnp.transpose(b, (1, 0, 2))  # c b e -> b c e
    b = jnp.reshape(b, (1, 3, 4, 1, 6))  # b c e -> 1 b c 1 e
    c = jnp.multiply(a, b)
    return c
With backend="jax.vmap"
import jax.numpy as jnp
import jax
def op(a, b):
    # Each vmap wraps the scalar multiply in one more batch axis. From the
    # innermost outward they map b, c, e (shared by both operands), d and a.
    # The out_axes values place each axis so the result is (a b c d e).
    c = jax.vmap(jnp.multiply, in_axes=(None, 0), out_axes=0)  # b
    c = jax.vmap(c, in_axes=(None, 0), out_axes=1)  # c
    c = jax.vmap(c, in_axes=(0, 2), out_axes=2)  # e
    c = jax.vmap(c, in_axes=(0, None), out_axes=2)  # d
    c = jax.vmap(c, in_axes=(0, None), out_axes=0)  # a
    a = jnp.reshape(a, (2, 5, 6))  # unflatten (d e)
    d = c(a, b)
    return d
With backend="jax.einsum"
import jax.numpy as jnp
def op(a, b):
    a = jnp.reshape(a, (2, 5, 6))  # unflatten (d e)
    # einsum subscripts are renamed: "abc" is (a d e) of x, "dec" is (c b e)
    # of y, so "aedbc" spells the output layout (a b c d e).
    c = jnp.einsum("abc,dec->aedbc", a, b)
    return c
Dot-product
The operation
x = jnp.zeros((2, 3))
y = jnp.zeros((4, 3))
einx.dot("a [b], c [b] -> c a", x, y)  # contracts axis b: output shape (4, 2)
compiles to the following code:
With backend="jax.numpylike"
import jax.numpy as jnp
def op(a, b):
    # Express the contraction over b as a matmul with a size-1 batch axis,
    # then restore the requested output order (c a).
    a = jnp.reshape(a, (1, 2, 3))  # 1 a b
    b = jnp.transpose(b, (1, 0))  # c b -> b c
    b = jnp.reshape(b, (1, 3, 4))  # 1 b c
    c = jnp.matmul(a, b)  # 1 a c
    c = jnp.reshape(c, (2, 4))  # a c
    c = jnp.transpose(c, (1, 0))  # c a
    return c
With backend="jax.vmap"
import jax.numpy as jnp
import jax
def op(a, b):
    # jnp.dot contracts two 1-D vectors along b. The inner vmap maps c over
    # the first axis of y; the outer vmap maps a over the first axis of x and
    # places it at output position 1, giving (c a).
    c = jax.vmap(jnp.dot, in_axes=(None, 0), out_axes=0)
    c = jax.vmap(c, in_axes=(0, None), out_axes=1)
    d = c(a, b)
    return d
With backend="jax.einsum"
import jax.numpy as jnp
def op(a, b):
    # Direct translation of "a [b], c [b] -> c a": the bracketed b is summed out.
    c = jnp.einsum("ab,cb->ca", a, b)
    return c
Indexing
The operation
x = jnp.zeros((2, 128, 128, 3))
# 50 (h, w) coordinate pairs.
# NOTE(review): jnp.zeros defaults to a float dtype, while the compiled code
# below feeds these values into jnp.take as indices. Confirm whether the
# example should use an integer dtype.
y = jnp.zeros((50, 2))
einx.get_at("b [h w] c, p [2] -> b p c", x, y)  # output shape (2, 50, 3)
compiles to the following code:
With backend="jax.numpylike"
import jax.numpy as jnp
def op(a, b):
    # Flatten x and gather with one flat index per output element (b p c):
    #   index = batch * 49152 + h * 384 + w * 3 + channel,
    # i.e. the row-major strides of the (2, 128, 128, 3) input.
    a = jnp.reshape(a, (98304,))
    c = jnp.arange(2, dtype="int32")
    c = jnp.multiply(c, 49152)  # batch offset
    c = jnp.reshape(c, (2, 1, 1))
    d = jnp.multiply(b[:, 0], 384)  # h coordinate times h stride
    d = jnp.reshape(d, (1, 50, 1))
    e = jnp.add(c, d)
    b = jnp.multiply(b[:, 1], 3)  # w coordinate times w stride
    b = jnp.reshape(b, (1, 50, 1))
    f = jnp.add(e, b)
    g = jnp.arange(3, dtype="int32")  # channel offset
    g = jnp.reshape(g, (1, 1, 3))
    h = jnp.add(f, g)  # broadcasts to (2, 50, 3)
    i = jnp.take(a, h)
    return i
With backend="jax.vmap"
import jax
def c(d, e):
    # Read the single element d[h, w] for one coordinate pair e = (h, w).
    return d[e[0], e[1]]

def op(a, b):
    # From the innermost vmap outward: map over the coordinate pairs p, then
    # over channels c (axis 2 once the batch axis is peeled off), then over
    # the batch b. The out_axes values lay the result out as (b p c).
    f = jax.vmap(c, in_axes=(None, 0), out_axes=0)
    f = jax.vmap(f, in_axes=(2, None), out_axes=1)
    f = jax.vmap(f, in_axes=(0, None), out_axes=0)
    g = f(a, b)
    return g