to_edge_transform_and_lower (#3483)
Summary:
Pull Request resolved: #3483

This diff introduces the to_edge_transform_and_lower API. The changes introduced are:
- Adding support in the Partitioner class for registering ops that a backend does not want decomposed
- Changes to _program.py adding the implementation of to_edge_transform_and_lower()
- Adding a basic test case verifying that Linear, SDPA, and Linear + SDPA are not decomposed when requested and that the corresponding backend consumes them.

Differential Revision: D56401086
tarun292 authored and facebook-github-bot committed May 2, 2024
1 parent abdddef commit 4ec6698
Showing 7 changed files with 365 additions and 11 deletions.
14 changes: 14 additions & 0 deletions exir/backend/partitioner.py
@@ -9,6 +9,8 @@
from types import MappingProxyType
from typing import Dict, List, Mapping, NamedTuple, Union

import torch

from executorch.exir.backend.backend_details import enforcedmethod
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import ExportedProgram
@@ -91,3 +93,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
PartitionResult: includes the tagged graph and the delegation spec to indicate what backend_id and compile_spec is used for each node and the tag created by the backend developers.
"""
pass

def ops_to_not_decompose(self) -> List[torch._ops.OpOverload]:
"""
Returns a list of operator overloads that should not be decomposed. When these ops are
registered and the backend is invoked through to_edge_transform_and_lower, it is
guaranteed that the program the backend receives will not have any of these ops
decomposed.
Returns:
List[torch._ops.OpOverload]: a list of operator overloads that should not be decomposed.
"""
pass
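For illustration, a minimal sketch of a backend partitioner opting into this hook; the class name and op choice below are hypothetical and not part of this diff:

from typing import List

import torch
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from torch.export import ExportedProgram


class MyLinearPartitioner(Partitioner):
    # Hypothetical partitioner that wants to receive aten.linear un-decomposed.

    def ops_to_not_decompose(self) -> List[torch._ops.OpOverload]:
        # to_edge_transform_and_lower guarantees these ops reach partition() intact.
        return [torch.ops.aten.linear.default]

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # A real backend would tag linear nodes for delegation here; this sketch
        # returns the program untagged.
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags={}
        )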
16 changes: 10 additions & 6 deletions exir/backend/test/backend_with_compiler_demo.py
@@ -83,15 +83,19 @@ def preprocess(
processed_bytes = ""
number_of_instruction = 0
debug_handle_map = {}
match_ops = [
exir_ops.edge.aten.sin.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.add.Tensor,
torch.ops.aten.sin.default,
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.scaled_dot_product_attention.default,
]

for node in edge_program.graph.nodes:
if node.op == "call_function":
# TODO(gasoonjia): remove the support of torch.ops.aten.sin.default after migrating serde to the edge dialect.
if (
node.target == exir_ops.edge.aten.sin.default
or node.target == exir_ops.edge.aten.mm.default
or node.target == exir_ops.edge.aten.add.Tensor
or node.target == torch.ops.aten.sin.default
):
if node.target in match_ops:
simple_op = DemoOp(
node.target.__name__,
int(torch.prod(torch.tensor(node.meta["val"].shape), 0).item()),
60 changes: 59 additions & 1 deletion exir/backend/test/op_partitioner_demo.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, final
from typing import Dict, final, List

import torch
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
@@ -121,3 +121,61 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
return PartitionResult(
tagged_exported_program=edge_exported_program, partition_tags=partition_tags
)


ops_not_to_decompose = [
torch.ops.aten.linear.default,
torch.ops.aten.scaled_dot_product_attention.default,
]

edge_ops_non_decomposed = [
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.scaled_dot_product_attention.default,
]


class OpsToNotDecomposeOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in edge_ops_non_decomposed


@final
class NonDecompTestPartitioner(Partitioner):
"""
Partitions all linear and scaled_dot_product_attention nodes, i.e. the ops this
partitioner registers to not be decomposed.
"""

def __init__(self) -> None:
self.op_support = any_chain(OpsToNotDecomposeOperatorSupport())
self.delegation_spec = DelegationSpec(
BackendWithCompilerDemo.__name__,
[CompileSpec("max_value", bytes([4]))],
)

def ops_to_not_decompose(self) -> List[torch._ops.OpOverload]:
return ops_not_to_decompose

def _partition_graph_module(
self,
graph_module: torch.fx.GraphModule,
) -> Dict[str, DelegationSpec]:
partition_tags: Dict[str, DelegationSpec] = {}
partition_list = generate_pattern_op_partitions(
graph_module, op_support=self.op_support
)
for partition in partition_list:
for node in partition.nodes:
delegation_tag = f"tag{partition.id}"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec

for _, submodule, _ in get_control_flow_submodules(graph_module):
ret_partition_tags = self._partition_graph_module(submodule)
partition_tags.update(ret_partition_tags)
return partition_tags

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
partition_tags = self._partition_graph_module(exported_program.graph_module)
return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)
1 change: 1 addition & 0 deletions exir/program/TARGETS
@@ -21,6 +21,7 @@ python_library(
deps = [
"//caffe2:torch",
"//executorch/exir:error",
"//executorch/exir:graph_module",
"//executorch/exir:pass_manager",
"//executorch/exir:print_program",
"//executorch/exir:schema",
2 changes: 2 additions & 0 deletions exir/program/__init__.py
@@ -15,13 +15,15 @@
ExecutorchProgramManager,
ExirExportedProgram,
to_edge,
to_edge_transform_and_lower,
)

__all__ = [
"ExirExportedProgram",
"ExecutorchProgram",
"_to_edge",
"to_edge",
"to_edge_transform_and_lower",
"edge_to_executorch_passes",
"EdgeProgramManager",
"ExecutorchProgramManager",
201 changes: 199 additions & 2 deletions exir/program/_program.py
@@ -20,6 +20,7 @@
from executorch.exir.emit import emit_program, EmitterOutput
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
from executorch.exir.error import ExportError
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.pass_manager import PassType
from executorch.exir.passes import (
base_post_op_replace_passes,
@@ -69,6 +70,17 @@

Val = Any

from torch.library import Library

# This is the reserved namespace into which ops are registered so that they will
# be prevented from being decomposed during to_edge_transform_and_lower.
edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP"
lib = Library(edge_no_decomp_namespace, "DEF")
# Map from aten ops to the transformed ops registered in the edge_no_decomp_namespace.
aten_op_to_transform_op = {}
# Map from the transformed ops registered in the edge_no_decomp_namespace to aten ops.
transform_op_to_aten_op = {}
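For intuition, the snippet below mirrors what to_edge_transform_and_lower does further down for each registered op, but for a single op and under a hypothetical namespace chosen only for illustration:

import torch
from torch.library import Library

_demo_lib = Library("DEMO_NO_DECOMP", "DEF")  # hypothetical namespace, illustration only

aten_op = torch.ops.aten.linear.default
# Re-declare the aten schema under the new namespace, e.g.
# "linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor".
_demo_lib.define(str(aten_op._schema).split("::")[1])
# Back the new op with the aten op itself so numerics are unchanged.
_demo_lib.impl(aten_op._schema.name.split("::")[1], aten_op, "CompositeExplicitAutograd")

# torch.ops.DEMO_NO_DECOMP.linear.default now behaves like aten.linear but has no
# decomposition registered for it, so run_decompositions() leaves it in the graph.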


def _get_updated_range_constraints(gm):
def get_shape_env(gm):
@@ -655,12 +667,15 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]


def _generate_edge_program(
name: str, config: EdgeCompileConfig, program: ExportedProgram
name: str,
config: EdgeCompileConfig,
program: ExportedProgram,
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
) -> ExportedProgram:

if config._check_ir_validity:
try:
EXIRATenDialectVerifier()(program.graph_module)
EXIRATenDialectVerifier(ops_set_to_not_decompose)(program.graph_module)
except ExportError as e:
logging.info(f"Input program {name} is not in ATen dialect.")
raise e
@@ -693,6 +708,7 @@ def _generate_edge_program(
check_edge_ops=config._use_edge_ops,
enable=config._check_ir_validity,
class_only=True,
exception_list=ops_set_to_not_decompose,
),
constants=program.constants,
)
@@ -703,6 +719,185 @@
return edge_program


def to_edge_transform_and_lower(
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
transform_passes: Optional[
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
] = None,
partitioner: Optional[Union[Partitioner, Dict[str, Partitioner]]] = None,
constant_methods: Optional[Dict[str, Any]] = None,
compile_config: Optional[EdgeCompileConfig] = None,
) -> "EdgeProgramManager":
"""
:func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of
exported programs in ATen dialect. It differs from to_edge in that it combines the
conversion from ATen dialect to edge dialect, the running of the transform passes, and
the lowering of the programs to their corresponding backends into a single API.
This is particularly useful for backends that register ops they do not want decomposed
and therefore rely on matching those non-decomposed ops. For such backends this is the
*only* API that should be used to lower to the edge dialect; using a combination of
to_edge(...) and to_backend(...) will result in inconsistent or incorrect behavior.
Args:
programs: Can be a single ExportedProgram or a dictionary mapping function names
to their corresponding ExportedPrograms. If only a single ExportedProgram is
provided it will be assigned the name "forward".
transform_passes: The passes can either be a list of passes, or a dictionary
mapping method names to lists of passes. If it is just a list of passes, all methods
in the given EdgeProgramManager will be transformed with the provided passes. If it
is a dictionary, only method names specified in the dictionary will be transformed
with their corresponding passes.
partitioner: The partitioner can either be a Partitioner subclass instance, or a
dictionary mapping method names to Partitioner subclass instances. If it is a
Partitioner subclass, all programs in the given EdgeProgramManager will be lowered
using the given partitioner. If it is a dictionary, only method names specified in
the dictionary will be lowered with the given partitioner.
constant_methods: An optional dictionary of method name to the constant value
returned by that method in eager mode. Often used to store config information on
Edge models.
compile_config: An optional argument used to provide greater control over the
transformation to edge dialect process.
Returns:
EdgeProgramManager
"""
ops_set_to_not_decompose = set()

assert not isinstance(constant_methods, EdgeCompileConfig)
config = compile_config or EdgeCompileConfig()
if not isinstance(programs, dict):
aten_programs = {"forward": programs}
else:
aten_programs = programs

# Collect all the ops that were registered by the partitioners to not be decomposed,
# and add them to ops_set_to_not_decompose.
if isinstance(partitioner, dict):
for curr_partitioner in partitioner.values():
ops_set_to_not_decompose = ops_set_to_not_decompose.union(
set(curr_partitioner.ops_to_not_decompose())
)
elif isinstance(partitioner, Partitioner):
ops_set_to_not_decompose = ops_set_to_not_decompose.union(
set(partitioner.ops_to_not_decompose())
)

edge_programs: Dict[str, ExportedProgram] = {}
for name, program in aten_programs.items():

# Returns the op in edge_no_decomp_namespace namespace for the aten
# op that is passed in.
def get_transformed_op(op_aten):
op_name = op_aten._schema.name.split("::")[1]
overload_name = op_aten._schema.overload_name
assert hasattr(torch.ops, edge_no_decomp_namespace)
op_namespace = getattr(torch.ops, edge_no_decomp_namespace)
op = getattr(op_namespace, op_name)
return getattr(op, overload_name)

for op_aten in ops_set_to_not_decompose:
# Check if the op is already cached in the table. If not, then we need to
# create a new op in the edge_no_decomp_namespace namespace.
if aten_op_to_transform_op.get(op_aten) is None:
# Extract the schema from the aten op.
op_schema = str(op_aten._schema).split("::")[1]
op_name = op_aten._schema.name.split("::")[1]
# Define an op in the edge_no_decomp_namespace namespace with the aten schema.
lib.define(op_schema)
# Define the implementation of the op in the edge_no_decomp_namespace namespace.
# Important to note that the implementation of the op is the same as the aten op.
lib.impl(op_name, op_aten, "CompositeExplicitAutograd")
# Cache the aten op and transformed op in their corresponding tables for future use.
aten_op_to_transform_op[op_aten] = get_transformed_op(op_aten)
transform_op_to_aten_op[str(aten_op_to_transform_op[op_aten])] = op_aten

# Iterate through the graph and replace the aten ops with the corresponding
# transformed ops.
for node in program.graph.nodes:
if node.op == "call_function" and node.target in ops_set_to_not_decompose:
node.target = aten_op_to_transform_op[node.target]
for _, submod, _ in get_control_flow_submodules(program.graph_module):
for node in submod.graph.nodes:
if (
node.op == "call_function"
and node.target in ops_set_to_not_decompose
):
node.target = aten_op_to_transform_op[node.target]

program = program.run_decompositions(_default_decomposition_table())

# Iterate through the graph and replace back the transformed ops with their
# corresponding aten ops.
for node in program.graph.nodes:
if (
node.op == "call_function"
and str(node.target) in transform_op_to_aten_op
):
node.target = transform_op_to_aten_op[str(node.target)]
for _, submod, _ in get_control_flow_submodules(program.graph_module):
for node in submod.graph.nodes:
if (
node.op == "call_function"
and str(node.target) in transform_op_to_aten_op
):
node.target = transform_op_to_aten_op[str(node.target)]

edge_programs[name] = program

edge_programs[name] = _generate_edge_program(
name, config, program, list(ops_set_to_not_decompose)
)

edge_manager = EdgeProgramManager(
edge_programs, constant_methods, config, list(ops_set_to_not_decompose)
)

if transform_passes is not None:
edge_manager = edge_manager.transform(transform_passes)

if partitioner is not None:
edge_manager = edge_manager.to_backend(partitioner)

for name, program in edge_manager._edge_programs.items():
if config._check_ir_validity:
try:
EXIREdgeDialectVerifier(
check_edge_ops=config._use_edge_ops,
enable=config._check_ir_validity,
class_only=True,
)()(program.graph_module)
except ExportError as e:
logging.info(
f"Program {name} after transform and lower is not in Edge dialect."
)
raise e

# Check that the ops that were registered to not be decomposed are not present in the
# graph anymore as the transform passes and backends should have consumed them by now.
for node in program.graph_module.graph.nodes:
if node.op == "call_function" and node.target in ops_set_to_not_decompose:
raise RuntimeError(
f"Found {node.target} in edge dialect program {name}. This should have either been consumed by the partitioner or one of the transformation passes."
)
for _, submod, _ in get_control_flow_submodules(program.graph_module):
for node in submod.graph.nodes:
if (
node.op == "call_function"
and node.target in ops_set_to_not_decompose
):
raise RuntimeError(
f"Found {node.target} in edge dialect program {name}. This should have either been consumed by the partitioner or one of the transformation passes."
)

return edge_manager
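For reference, a minimal end-to-end sketch of the intended usage, combining the demo partitioner added in this diff with the new API; the module and example inputs below are illustrative:

import torch
from executorch.exir.backend.test.op_partitioner_demo import NonDecompTestPartitioner
from executorch.exir.program import to_edge_transform_and_lower


class LinearModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


ep = torch.export.export(LinearModule().eval(), (torch.randn(2, 4),))

# aten.linear is kept un-decomposed, converted to edge dialect, and handed to the
# demo backend for delegation, all in one call.
edge_manager = to_edge_transform_and_lower(ep, partitioner=NonDecompTestPartitioner())
executorch_program = edge_manager.to_executorch()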


def to_edge(
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
constant_methods: Optional[Dict[str, Any]] = None,
@@ -755,6 +950,7 @@ def __init__(
edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
constant_methods: Optional[Dict[str, Any]] = None,
compile_config: Optional[EdgeCompileConfig] = None,
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
):
"""
Should not be called directly by users. Users should use :func:`to_edge` instead.
Expand All @@ -769,6 +965,7 @@ def __init__(
EXIREdgeDialectVerifier(
check_edge_ops=config._use_edge_ops,
enable=config._check_ir_validity,
exception_list=ops_set_to_not_decompose,
)(program.graph_module)
except ExportError as e:
logging.info(f"Input program {name} is not in aten dialect.")
