Hi, I was working with `AqtEinsum`, and in this particular case I got a `ValueError`, although the same operation works fine with `jnp.einsum`.
This works fine:
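```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)  # seed 42, as mentioned in this report
x = jax.random.normal(key, [1, 2, 4])
w = jax.random.normal(key, [2, 4, 4])
z = jnp.einsum('...ij,hjk->...ik', x, w)
z
```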
This does not:
```python
class SimpleDense(nn.Module):
  features: int
  config = aqt_config.fully_quantized()

  @nn.compact
  def __call__(self, x):
    d = x.shape[-1]
    kernel = self.param('kernel', nn.initializers.normal(), (2, d, self.features))
    einsum = aqt.AqtEinsum(self.config)
    return einsum('...ij,hjk->...ik', x, kernel)

model = SimpleDense(features = 4)
params = model.init(key, x)
```
```
ValueError                                Traceback (most recent call last)
<ipython-input-41-bf39ae22f96a> in <cell line: 2>()
      1 model = SimpleDense(features = 4)
----> 2 params = model.init(key, x)

    [... skipping hidden 9 frame]

1 frames
<ipython-input-40-29c53684ec5e> in __call__(self, x)
     10     einsum = aqt.AqtEinsum(self.config)
     11 
---> 12     return einsum('...ij,hjk->...ik', x, kernel)

    [... skipping hidden 2 frame]

/usr/local/lib/python3.10/dist-packages/aqt/jax/v2/flax/aqt_flax.py in __call__(self, eqn, lhs_g, rhs_g)
    315     einsum = functools.partial(aqt_dot_general.einsum, eqn=eqn)
    316     a = jax.make_jaxpr(einsum)(lhs=lhs_in, rhs=rhs_in)
--> 317     [lhs_g_id, rhs_g_id] = a.eqns[0].invars
    318     [lhs_l_id, rhs_l_id] = a.jaxpr.invars
    319     not_swap = lhs_g_id == lhs_l_id and rhs_g_id == rhs_l_id

ValueError: not enough values to unpack (expected 2, got 1)
```
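Line 317 unpacks `a.eqns[0].invars` into exactly two variables, so it assumes the first equation in the traced jaxpr is the two-operand contraction. My unverified guess is that, because `h` appears only in the kernel and not in the output, `jnp.einsum` first sums `h` out with a separate one-input reduction, so the first equation has a single input. The sketch below traces plain `jnp.einsum` (not AQT's wrapper) and lists each top-level equation's primitive name and input count for both subscripts. `describe_eqns` is a hypothetical helper, and the exact output may differ between JAX versions (for example, if `jnp.einsum` is not inlined you may see a single `pjit` equation instead).

```python
import jax
import jax.numpy as jnp

def describe_eqns(eqn_str, lhs_shape, rhs_shape):
    """Hypothetical helper: trace jnp.einsum and return
    (primitive name, number of inputs) for each top-level jaxpr equation."""
    lhs = jnp.zeros(lhs_shape, jnp.float32)
    rhs = jnp.zeros(rhs_shape, jnp.float32)
    closed = jax.make_jaxpr(lambda a, b: jnp.einsum(eqn_str, a, b))(lhs, rhs)
    return [(e.primitive.name, len(e.invars)) for e in closed.jaxpr.eqns]

# Failing case from this report: 'h' only appears on the kernel side.
print(describe_eqns('...ij,hjk->...ik', (1, 2, 4), (2, 4, 4)))
# Working case: a plain contraction over 'j'.
print(describe_eqns('...ij,jk->...ik', (1, 2, 4), (4, 4)))
```

If the first entry for the failing case has one input while the working case starts with a two-input equation, that would match the unpacking error above.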
However, if I change the einsum subscripts and the kernel shape as follows:
```python
...
kernel = self.param('kernel', nn.initializers.normal(), (d, self.features))
einsum = aqt.AqtEinsum(self.config)
return einsum('...ij,jk->...ik', x, kernel)
...
```
the code works as expected, without any errors.
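Since `h` is in the kernel but not in the output, `'...ij,hjk->...ik'` sums over `h` as well as `j`, which should equal the working two-dimensional form applied to `kernel.sum(axis=0)`. The sketch below checks that equivalence with plain `jnp.einsum` only; pre-summing the kernel before calling `AqtEinsum` might be a possible workaround, but I have not tested it with AQT, and quantizing a pre-summed kernel is not necessarily numerically identical to quantizing the 3-D kernel.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)           # seed from this report
kx, kw = jax.random.split(key)         # separate keys for x and w
x = jax.random.normal(kx, (1, 2, 4))
w = jax.random.normal(kw, (2, 4, 4))

# Original equation: 'h' is absent from the output, so it is summed over.
z_3d = jnp.einsum('...ij,hjk->...ik', x, w)
# Same contraction with 'h' summed out of the kernel up front,
# using the subscripts that work with AqtEinsum.
z_2d = jnp.einsum('...ij,jk->...ik', x, w.sum(axis=0))

print(z_3d.shape)  # (1, 2, 4)
print(jnp.allclose(z_3d, z_2d, atol=1e-4, rtol=1e-4))
```

In `SimpleDense`, this would mean calling `einsum('...ij,jk->...ik', x, kernel.sum(axis=0))` instead of the 3-D subscripts.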
For reference, I'm using aqt version 0.5.0, and the random seed is 42.