
Two consecutive Unsqueeze operations in a QAT model throw an error at compilation time #682

Open
raeell opened this issue May 16, 2024 · 2 comments
Labels: bug (Something isn't working)

Comments


raeell commented May 16, 2024

Summary

When using two consecutive reshape operations in a QAT model, such as two torch.unsqueeze calls, compilation throws the ValueError "Could not determine a unique scale for the quantization! Please check the ONNX graph of this model.", even though the two torch.unsqueeze operations sit between two QuantIdentity layers.
This seems to happen especially when the n_bits parameter passed to compile_brevitas_qat_model is smaller than the bit width used for the QAT model, for instance a bit width of 8 for the QuantIdentity layers but n_bits=6.

Description

  • Versions affected: 1.5.0
  • Python version: 3.9.16

Minimal code to reproduce the bug:

import brevitas.nn as qnn
import torch
import torch.nn as nn
from concrete.ml.torch.compile import compile_brevitas_qat_model

class Unsqueeze(nn.Module):

    def __init__(self, bit_width):
        super().__init__()
        self.id1 = qnn.QuantIdentity(bit_width=bit_width)
        self.conv1 = qnn.QuantConv2d(1, 1, 1, bit_width=bit_width, bias=False)

    def forward(self, x):
        """Forward pass of the model."""
        x = self.id1(x)
        # Two consecutive reshape (unsqueeze) operations between the QuantIdentity calls
        x = x.unsqueeze(1)
        x = x.unsqueeze(1)
        x = self.id1(x)
        x = self.conv1(x)

        return x
    
model = Unsqueeze(bit_width=8)
tensor_ = torch.randn(1, 200)

# Compilation succeeds when n_bits matches the 8-bit width used in the QAT model
compile_brevitas_qat_model(model, tensor_, verbose=False, n_bits=8)
print("Compilation with 8 bits successful")
compile_brevitas_qat_model(model, tensor_, verbose=False, n_bits=7)
print("Compilation with 7 bits successful")
# Compilation with n_bits=6 raises the ValueError quoted in the summary
try:
    compile_brevitas_qat_model(model, tensor_, verbose=False, n_bits=6)
except Exception as e:
    print(e)
    print("Compilation with 6 bits failed")

andrei-stoian-zama (Collaborator)

Thank you for raising this issue, we'll look into it!

jfrery (Collaborator) commented May 22, 2024

Hi @raeell,

You are correct, the error message isn't very clear here. Some assertions seem to fail occasionally when fewer bits are set at compilation time than in the Brevitas layers.

Brevitas already carries all the necessary n_bits information, so the n_bits parameter should be hidden from the user at compilation time. We will improve this. Thanks for the issue!
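Until that change lands, a possible stop-gap, based on the repro above and on the compile_brevitas_qat_model API, is to keep n_bits aligned with the bit width already encoded in the Brevitas layers, or to omit n_bits entirely so it is inferred from the model. That the default infers it from the Brevitas layers is an assumption about this Concrete ML version, so the sketch below is only illustrative:

# Sketch of a possible workaround, reusing the Unsqueeze model from the repro above.
# Option 1 is demonstrated by the repro itself; option 2 relies on n_bits being optional
# and derived from the Brevitas layers, which is an assumption about this API version.
import torch
from concrete.ml.torch.compile import compile_brevitas_qat_model

model = Unsqueeze(bit_width=8)  # class defined in the repro above
tensor_ = torch.randn(1, 200)

# Option 1: pass n_bits equal to the bit width used in the QAT model
compiled_module = compile_brevitas_qat_model(model, tensor_, n_bits=8)

# Option 2 (assumption): omit n_bits and let it be inferred from the Brevitas layers
compiled_module = compile_brevitas_qat_model(model, tensor_)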
