AqtEinsum 'not enough values to unpack' #498

Open
brandnewchoppa opened this issue Feb 23, 2024 · 0 comments

Hi, I was working with AqtEinsum and in this particular case I got a ValueError, although the same operation works fine with jnp.einsum.

This works fine:

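import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)  # seed 42, as stated at the end of this report
x = jax.random.normal(key, [1, 2, 4])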
w = jax.random.normal(key, [2, 4, 4])

z = jnp.einsum('...ij,hjk->...ik', x, w)
z

This does not:

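# Imports assumed from the AQT README; the original post does not show them.
import flax.linen as nn
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2.flax import aqt_flax as aqt

class SimpleDense(nn.Module):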
    features: int
    config = aqt_config.fully_quantized()

    @nn.compact
    def __call__(self, x):
        d = x.shape[-1]

        kernel = self.param('kernel', nn.initializers.normal(), (2, d, self.features))
        einsum = aqt.AqtEinsum(self.config)

        return einsum('...ij,hjk->...ik', x, kernel)

model = SimpleDense(features = 4)
params = model.init(key, x)

ValueError                                Traceback (most recent call last)
<ipython-input-41-bf39ae22f96a> in <cell line: 2>()
      1 model = SimpleDense(features = 4)
----> 2 params = model.init(key, x)

    [... skipping hidden 9 frame]

<ipython-input-40-29c53684ec5e> in __call__(self, x)
     10         einsum = aqt.AqtEinsum(self.config)
     11 
---> 12         return einsum('...ij,hjk->...ik', x, kernel)

    [... skipping hidden 2 frame]

/usr/local/lib/python3.10/dist-packages/aqt/jax/v2/flax/aqt_flax.py in __call__(self, eqn, lhs_g, rhs_g)
    315     einsum = functools.partial(aqt_dot_general.einsum, eqn=eqn)
    316     a = jax.make_jaxpr(einsum)(lhs=lhs_in, rhs=rhs_in)
--> 317     [lhs_g_id, rhs_g_id] = a.eqns[0].invars
    318     [lhs_l_id, rhs_l_id] = a.jaxpr.invars
    319     not_swap = lhs_g_id == lhs_l_id and rhs_g_id == rhs_l_id

ValueError: not enough values to unpack (expected 2, got 1)
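
The failing line unpacks exactly two input variables from the first equation of the traced jaxpr. A minimal sketch for inspecting such a trace with plain `jnp.einsum` follows. It assumes that `aqt_dot_general.einsum` traces in a similar way to `jnp.einsum`, which this report does not confirm, and `trace_einsum` is a hypothetical helper name, not part of AQT.

```python
import jax
import jax.numpy as jnp

EQN = '...ij,hjk->...ik'


def trace_einsum(eqn, lhs, rhs):
    """Hypothetical helper: trace jnp.einsum and return its ClosedJaxpr."""
    return jax.make_jaxpr(lambda a, b: jnp.einsum(eqn, a, b))(lhs, rhs)


key = jax.random.PRNGKey(42)
x = jax.random.normal(key, [1, 2, 4])
w = jax.random.normal(key, [2, 4, 4])

closed = trace_einsum(EQN, x, w)

# List each traced equation with its primitive name and number of inputs.
for i, eqn in enumerate(closed.eqns):
    print(i, eqn.primitive.name, len(eqn.invars))
```

If the first equation listed takes a single input (for example, a reduction over `h`), the two-element unpacking in `aqt_flax.py` would fail in the way shown above. This is a guess to check, not a confirmed diagnosis.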

However, with the following einsum subscripts and kernel shape:

...
kernel = self.param('kernel', nn.initializers.normal(), (d, self.features))
einsum = aqt.AqtEinsum(self.config)

return einsum('...ij,jk->...ik', x, kernel)
...

The code works as expected, without any errors.
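
Because `h` appears only in the kernel subscripts and not in the output, `jnp.einsum` sums over it. A possible workaround sketch is to sum the kernel over its leading axis first and then use the two-dimensional equation that works. This is shown with plain `jnp.einsum` only and has not been tested against `AqtEinsum`; with quantization enabled, summing before quantizing may not give numerically identical results.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
x = jax.random.normal(key, [1, 2, 4])
w = jax.random.normal(key, [2, 4, 4])

# Original equation: `h` is absent from the output, so it is summed out.
z_full = jnp.einsum('...ij,hjk->...ik', x, w)

# Same contraction: reduce the kernel over `h` first, then use the 2-D equation.
w_summed = w.sum(axis=0)
z_reduced = jnp.einsum('...ij,jk->...ik', x, w_summed)

print(jnp.allclose(z_full, z_reduced, atol=1e-5))
```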

For reference, I'm using aqt version 0.5.0, and the random seed is 42.
