Skip to content

Commit

Permalink
get compile config back to EdgeProgramManager.transform (#3500)
Browse files Browse the repository at this point in the history
Summary:

This diff introduces the compile configuration into the `EdgeProgramManager.transform` method, ensuring that the constructor and transform function of` EdgeProgramManager` maintain consistent verification configuration. The default verification configuration of `EdgeProgramManager.transform` will now match that of its constructor.

This update brings two key improvements:
1. Enhanced Verification for the transform Function: With this update, we can now customize all necessary configuration verifiers within the `transform` function.
2. Improved Consistency within the `EdgeProgramManager` Class: Post-update, all verifiers within the same `EdgeProgramManager` will have identical functionality. This not only makes the logic more intuitive but also reduces redundancy in user settings.

Reviewed By: JacobSzwejbka

Differential Revision: D56804195
  • Loading branch information
Gasoonjia authored and facebook-github-bot committed May 3, 2024
1 parent 3a2b2e8 commit 29905e7
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 34 deletions.
3 changes: 1 addition & 2 deletions examples/cadence/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def export_model(model, example_inputs):

# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
[ReplacePT2QuantWithCadenceQuant(), ReplacePT2DequantWithCadenceDequant()],
check_ir_validity=False,
[ReplacePT2QuantWithCadenceQuant(), ReplacePT2DequantWithCadenceDequant()]
)

exec_prog = cadence_prog_manager.to_executorch()
Expand Down
29 changes: 17 additions & 12 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,14 +761,14 @@ def __init__(
Constructs an EdgeProgramManager from an existing set of exported programs in edge dialect.
"""
config = compile_config or EdgeCompileConfig()
self.compile_config = compile_config or EdgeCompileConfig()
if not isinstance(edge_programs, dict):
edge_programs = {"forward": edge_programs}
for name, program in edge_programs.items():
try:
EXIREdgeDialectVerifier(
check_edge_ops=config._use_edge_ops,
enable=config._check_ir_validity,
enable=self.compile_config._check_ir_validity,
check_edge_ops=self.compile_config._use_edge_ops,
)(program.graph_module)
except ExportError as e:
logging.info(f"Input program {name} is not in aten dialect.")
Expand Down Expand Up @@ -800,8 +800,7 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram:
def transform(
self,
passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
check_ir_validity: bool = True,
# We should also probably add check_edge_ops here as well
compile_config: Optional[EdgeCompileConfig] = None,
) -> "EdgeProgramManager":
"""
Transforms the program according to the provided passes.
Expand All @@ -813,29 +812,35 @@ def transform(
will be transformed with the provided passes. If it is a
dictionary, only method names specified in the dictionary will be
transformed with their corresponding passes.
compile_config: Compile config to use for veriy the correctness of model
graph after each pass. If not specified, the compile config of the
calling EdgeProgramManager will be used.
Returns:
EdgeProgramManager: A copy of the calling EdgeProgramManager with the
transformations applied.
"""
compile_config = compile_config or self.compile_config
new_programs: Dict[str, ExportedProgram] = {}
if isinstance(passes, dict):
for name, program in self._edge_programs.items():
if name in passes.keys():
new_programs[name] = _transform(program, *passes[name])
EXIREdgeDialectVerifier(enable=check_ir_validity)(
new_programs[name].graph_module
)
EXIREdgeDialectVerifier(
enable=compile_config._check_ir_validity,
check_edge_ops=compile_config._use_edge_ops,
)(new_programs[name].graph_module)
else:
new_programs[name] = copy.deepcopy(program)

else: # apply passes to every method
for name, program in self._edge_programs.items():
new_programs[name] = _transform(program, *passes)
EXIREdgeDialectVerifier(enable=check_ir_validity)(
new_programs[name].graph_module
)
config = EdgeCompileConfig(_check_ir_validity=check_ir_validity)
EXIREdgeDialectVerifier(
enable=compile_config._check_ir_validity,
check_edge_ops=compile_config._use_edge_ops,
)(new_programs[name].graph_module)
config = EdgeCompileConfig(_check_ir_validity=compile_config._check_ir_validity)
return EdgeProgramManager(
new_programs, copy.deepcopy(self._config_methods), config
)
Expand Down
8 changes: 4 additions & 4 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def forward(self, x_raw, h, c):
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)

new_prog = edge_prog.transform([SpecPropPass()], check_ir_validity=False)
new_prog = edge_prog.transform([SpecPropPass()])

new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
self.assertIsNotNone(new_gm_res)
Expand Down Expand Up @@ -679,7 +679,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
),
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
new_prog = prog.transform([EdgeToBackendOpsPass()], check_ir_validity=False)
new_prog = prog.transform([EdgeToBackendOpsPass()])
self.assertIsNotNone(new_prog.exported_program().graph_module)
converted_gm = new_prog.exported_program().graph_module

Expand Down Expand Up @@ -806,15 +806,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)

new_prog = prog.transform([EdgeToBackendOpsPass()], check_ir_validity=False)
new_prog = prog.transform([EdgeToBackendOpsPass()])
gm = new_prog.exported_program().graph_module
gm.print_readable()
*_, ones, out = gm.graph.nodes
print(f"Before ExportPass: {ones.format_node()}")
self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt))
self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0)

new_prog = new_prog.transform([ExportPass()], check_ir_validity=False)
new_prog = new_prog.transform([ExportPass()])
gm = new_prog.exported_program().graph_module
gm.print_readable()
*_, ones, out = gm.graph.nodes
Expand Down
22 changes: 6 additions & 16 deletions exir/tests/test_quant_fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ def forward(self, x, y):
config = EdgeCompileConfig(_check_ir_validity=False)
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
)
m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
# check that we are using functional variant of q/dq/add
FileCheck().check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
Expand Down Expand Up @@ -100,9 +98,7 @@ def forward(self, x, y):
config = EdgeCompileConfig(_check_ir_validity=False)
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
)
m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
# check that we are using functional variant of q/dq/add/reshape
# make sure we only have two quant and one dequant since the q/dq around reshape
# should be fused
Expand Down Expand Up @@ -157,9 +153,7 @@ def forward(self, x, y):
config = EdgeCompileConfig(_check_ir_validity=False)
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
)
m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
# check that we are using functional variant of q/dq/add/slice
# make sure we only have one quant and one dequant since the q/dq around slice
# should be fused
Expand Down Expand Up @@ -206,7 +200,7 @@ def forward(self, x, y):
config = EdgeCompileConfig(_check_ir_validity=False)
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform([QuantFusionPass()], check_ir_validity=False)
m = m.transform([QuantFusionPass()])
# check that we are using functional variant of q/dq/cat
FileCheck().check_count(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
Expand Down Expand Up @@ -301,9 +295,7 @@ def forward(self, indices):
)
m = to_edge(export(m, example_inputs), compile_config=compile_config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
)
m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
# check that we are using functional variant of q/dq/cat
FileCheck().check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
Expand Down Expand Up @@ -359,9 +351,7 @@ def forward(self, indices):
)
m = to_edge(export(m, example_inputs), compile_config=compile_config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
)
m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
# check that we are using functional variant of q/dq/cat
FileCheck().check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
Expand Down

0 comments on commit 29905e7

Please sign in to comment.