Skip to content

Commit

Permalink
get compile config back to EdgeProgramManager.transform (pytorch#3500)
Browse files Browse the repository at this point in the history
Summary:

This diff brings compile config back to `EdgeProgramManager.transform` function, to make the EdgeDialectVerifier verify graph in a finer granularity, and not introduce too many attritute to the `transform` API.

Differential Revision: D56804195
  • Loading branch information
Gasoonjia authored and facebook-github-bot committed May 3, 2024
1 parent 53078c4 commit 5c09199
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 16 deletions.
6 changes: 5 additions & 1 deletion examples/cadence/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from ....exir import EdgeCompileConfig

from ...portable.utils import save_pte_program

from .compiler import export_to_edge
Expand Down Expand Up @@ -55,7 +57,9 @@ def export_model(model, example_inputs):
# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
[ReplacePT2QuantWithCadenceQuant(), ReplacePT2DequantWithCadenceDequant()],
check_ir_validity=False,
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)

exec_prog = cadence_prog_manager.to_executorch()
Expand Down
10 changes: 5 additions & 5 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,8 +800,7 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram:
def transform(
self,
passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
check_ir_validity: bool = True,
# We should also probably add check_edge_ops here as well
compile_config: Optional[EdgeCompileConfig] = None,
) -> "EdgeProgramManager":
"""
Transforms the program according to the provided passes.
Expand All @@ -818,12 +817,13 @@ def transform(
EdgeProgramManager: A copy of the calling EdgeProgramManager with the
transformations applied.
"""
compile_config = compile_config or EdgeCompileConfig()
new_programs: Dict[str, ExportedProgram] = {}
if isinstance(passes, dict):
for name, program in self._edge_programs.items():
if name in passes.keys():
new_programs[name] = _transform(program, *passes[name])
EXIREdgeDialectVerifier(enable=check_ir_validity)(
EXIREdgeDialectVerifier(enable=compile_config._check_ir_validity)(
new_programs[name].graph_module
)
else:
Expand All @@ -832,10 +832,10 @@ def transform(
else: # apply passes to every method
for name, program in self._edge_programs.items():
new_programs[name] = _transform(program, *passes)
EXIREdgeDialectVerifier(enable=check_ir_validity)(
EXIREdgeDialectVerifier(enable=compile_config._check_ir_validity)(
new_programs[name].graph_module
)
config = EdgeCompileConfig(_check_ir_validity=check_ir_validity)
config = EdgeCompileConfig(_check_ir_validity=compile_config._check_ir_validity)
return EdgeProgramManager(
new_programs, copy.deepcopy(self._config_methods), config
)
Expand Down
28 changes: 24 additions & 4 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,12 @@ def forward(self, x_raw, h, c):
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)

new_prog = edge_prog.transform([SpecPropPass()], check_ir_validity=False)
new_prog = edge_prog.transform(
[SpecPropPass()],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)

new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
self.assertIsNotNone(new_gm_res)
Expand Down Expand Up @@ -679,7 +684,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
),
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
new_prog = prog.transform([EdgeToBackendOpsPass()], check_ir_validity=False)
new_prog = prog.transform(
[EdgeToBackendOpsPass()],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)
self.assertIsNotNone(new_prog.exported_program().graph_module)
converted_gm = new_prog.exported_program().graph_module

Expand Down Expand Up @@ -806,15 +816,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)

new_prog = prog.transform([EdgeToBackendOpsPass()], check_ir_validity=False)
new_prog = prog.transform(
[EdgeToBackendOpsPass()],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)
gm = new_prog.exported_program().graph_module
gm.print_readable()
*_, ones, out = gm.graph.nodes
print(f"Before ExportPass: {ones.format_node()}")
self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt))
self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0)

new_prog = new_prog.transform([ExportPass()], check_ir_validity=False)
new_prog = new_prog.transform(
[ExportPass()],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)
gm = new_prog.exported_program().graph_module
gm.print_readable()
*_, ones, out = gm.graph.nodes
Expand Down
20 changes: 14 additions & 6 deletions exir/tests/test_quant_fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def forward(self, x, y):
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
[QuantFusionPass(_fix_node_meta_val=True)],
compile_config=config,
)
# check that we are using functional variant of q/dq/add
FileCheck().check(
Expand Down Expand Up @@ -101,7 +102,8 @@ def forward(self, x, y):
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
[QuantFusionPass(_fix_node_meta_val=True)],
compile_config=config,
)
# check that we are using functional variant of q/dq/add/reshape
# make sure we only have two quant and one dequant since the q/dq around reshape
Expand Down Expand Up @@ -158,7 +160,8 @@ def forward(self, x, y):
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
[QuantFusionPass(_fix_node_meta_val=True)],
compile_config=config,
)
# check that we are using functional variant of q/dq/add/slice
# make sure we only have one quant and one dequant since the q/dq around slice
Expand Down Expand Up @@ -206,7 +209,10 @@ def forward(self, x, y):
config = EdgeCompileConfig(_check_ir_validity=False)
m = to_edge(export(m, example_inputs), compile_config=config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform([QuantFusionPass()], check_ir_validity=False)
m = m.transform(
[QuantFusionPass()],
compile_config=config,
)
# check that we are using functional variant of q/dq/cat
FileCheck().check_count(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
Expand Down Expand Up @@ -302,7 +308,8 @@ def forward(self, indices):
m = to_edge(export(m, example_inputs), compile_config=compile_config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
[QuantFusionPass(_fix_node_meta_val=True)],
compile_config=compile_config,
)
# check that we are using functional variant of q/dq/cat
FileCheck().check(
Expand Down Expand Up @@ -360,7 +367,8 @@ def forward(self, indices):
m = to_edge(export(m, example_inputs), compile_config=compile_config)
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
m = m.transform(
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
[QuantFusionPass(_fix_node_meta_val=True)],
compile_config=compile_config,
)
# check that we are using functional variant of q/dq/cat
FileCheck().check(
Expand Down

0 comments on commit 5c09199

Please sign in to comment.