Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for exception list in EXIRATenDialectVerifierBase #3481

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
66 changes: 44 additions & 22 deletions exir/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def _check_valid_dim_order_ops(op, use_dim_order) -> None:
class EXIRATenDialectVerifierBase(Verifier):
dialect = "OLD_EXIR_ATEN_DISABLED"

def __init__(
self, exception_list: Optional[List[torch._ops.OpOverload]] = None
) -> None:
super().__init__()
self._exception_list = exception_list if exception_list else []

def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
return (
torch.fx.GraphModule,
Expand All @@ -74,23 +80,33 @@ def __call__(self, *args, **kwargs):
class EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
dialect = "OLD_EXIR_ATEN"

def _get_exception_list(self) -> List[torch._ops.OpOverload]:
exception_list = [
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.aten._upsample_bilinear2d_aa.default,
torch.ops.aten.quantize_per_tensor.default,
torch.ops.aten.dequantize.self,
torch.ops.aten.max.default, # TODO(T188268054)
torch.ops.aten.min.default, # TODO(T188268054)
torch.ops.aten.full_like.default, # TODO(T183507359)
]
exception_list += self._exception_list

return exception_list

def check_valid_op(self, op):
if isinstance(op, OpOverload):
# TODO These special ops should be removable easily.
if op.namespace in (
"quantized_decomposed",
"boltnn_nimble",
"nimble",
"quantized",
"dim_order_ops",
) or op in (
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.aten._upsample_bilinear2d_aa.default,
torch.ops.aten.quantize_per_tensor.default,
torch.ops.aten.dequantize.self,
torch.ops.aten.max.default, # TODO(T188268054)
torch.ops.aten.min.default, # TODO(T188268054)
torch.ops.aten.full_like.default, # TODO(T183507359)
if (
op.namespace
in [
"quantized_decomposed",
"boltnn_nimble",
"nimble",
"quantized",
"dim_order_ops",
]
or op in self._get_exception_list()
):
return
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
Expand Down Expand Up @@ -150,6 +166,7 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
def EXIREdgeDialectVerifier( # noqa: C901
edge_compile_config: Optional[EdgeCompileConfig] = None,
class_only: bool = False,
exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
class _EXIREdgeDialectVerifier(Verifier):
dialect = "EDGE"
Expand All @@ -161,13 +178,14 @@ def __init__(self) -> None:
self.check_edge_ops = _edge_compile_config._use_edge_ops
self.use_dim_order = not _edge_compile_config._skip_dim_order

self.aten_op_verifier = EXIRATenDialectVerifier()
self.aten_op_verifier = EXIRATenDialectVerifier(exception_list)
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op

if self.check_edge_ops:
self.check_valid_op = self.check_valid_edge_op
else:
self.check_valid_op = self.check_valid_aten_op
self._exception_list = exception_list if exception_list else []

def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
return (
Expand All @@ -183,13 +201,17 @@ def allowed_op_types(self):
def check_valid_edge_op(self, op):
if not self.enable:
return
if op in [
operator.getitem,
torch.ops.aten.sym_size.int,
torch.ops.aten.scalar_tensor.default,
torch.ops.aten._assert_async.msg,
torch.ops.aten._assert_scalar.default,
]:
if (
op
in [
operator.getitem,
torch.ops.aten.sym_size.int,
torch.ops.aten.scalar_tensor.default,
torch.ops.aten._assert_async.msg,
torch.ops.aten._assert_scalar.default,
]
+ self._exception_list
):
return

if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
Expand Down