fix constant tagging in mps backend
Summary:
Tested together with #3399. With both changes applied, this command passes:
```
python -m examples.models.llama2.export_llama -kv --mps
```
Without this diff, the export fails with:
```
in _verify_exported_program_signature
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: Buffer output getitem_1 does not point to a buffer that exists.
Dict of buffers that are mutated, in order: {'getitem_1': 'layers_0_attention_SDPA_kv_cache_k_cache', 'getitem': 'layers_0_attention_SDPA_kv_cache_v_cache', 'getitem_3': 'layers_1_attention_SDPA_kv_cache_k_cache', 'getitem_2': 'layers_1_attention_SDPA_kv_cache_v_cache', 'getitem_5': 'layers_2_attention_SDPA_kv_cache_k_cache', 'getitem_4': 'layers_2_attention_SDPA_kv_cache_v_cache', 'getitem_7': 'layers_3_attention_SDPA_kv_cache_k_cache', 'getitem_6': 'layers_3_attention_SDPA_kv_cache_v_cache', 'getitem_9': 'layers_4_attention_SDPA_kv_cache_k_cache', 'getitem_8': 'layers_4_attention_SDPA_kv_cache_v_cache'}
Buffer nodes available: []
```
The root cause is that the `is_parameter` check tags all data, including mutable buffers such as the KV cache buffers listed in the error above.
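
To make the failure mode concrete, here is a minimal, torch-free sketch. `ToySignature` is a hypothetical stand-in whose field names mirror the `inputs_to_parameters`, `inputs_to_buffers` and `buffers_to_mutate` mappings of an export graph signature. The helper functions and the node and buffer names are made up for illustration, and this is not how `tag_constant_data` or the verifier is actually implemented.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToySignature:
    """Hypothetical stand-in for an export graph signature (not the torch class)."""

    # placeholder node name -> parameter name
    inputs_to_parameters: Dict[str, str] = field(default_factory=dict)
    # placeholder node name -> buffer name
    inputs_to_buffers: Dict[str, str] = field(default_factory=dict)
    # output node name -> name of the buffer it writes back to
    buffers_to_mutate: Dict[str, str] = field(default_factory=dict)


def constant_placeholders(sig: ToySignature) -> List[str]:
    """Placeholders that never change: parameters plus non-mutated buffers."""
    mutated = set(sig.buffers_to_mutate.values())
    constants = list(sig.inputs_to_parameters)
    constants += [n for n, buf in sig.inputs_to_buffers.items() if buf not in mutated]
    return constants


def dangling_mutations(sig: ToySignature, delegated: List[str]) -> List[str]:
    """Output nodes whose target buffer is gone once the `delegated`
    placeholders have been moved out of the top-level graph."""
    remaining = {buf for n, buf in sig.inputs_to_buffers.items() if n not in delegated}
    return [out for out, buf in sig.buffers_to_mutate.items() if buf not in remaining]


# Illustrative data only; every name below is invented for this sketch.
sig = ToySignature(
    inputs_to_parameters={"p_wq": "layers.0.attention.wq.weight"},
    inputs_to_buffers={
        "b_k_cache": "layers_0_k_cache",
        "b_v_cache": "layers_0_v_cache",
        "b_mask": "mask",
    },
    buffers_to_mutate={"getitem_1": "layers_0_k_cache", "getitem": "layers_0_v_cache"},
)

# Tagging every parameter and buffer takes the KV caches out of the top-level graph,
# so the mutation outputs no longer point at an existing buffer.
tag_everything = list(sig.inputs_to_parameters) + list(sig.inputs_to_buffers)
print(dangling_mutations(sig, tag_everything))  # ['getitem_1', 'getitem']

# Tagging only true constants keeps every mutated buffer reachable.
safe = constant_placeholders(sig)
print(safe)                                     # ['p_wq', 'b_mask']
print(dangling_mutations(sig, safe))            # []
```

The first call corresponds to the `Buffer nodes available: []` situation in the error above: the mutation outputs still exist, but the buffers they refer to have been tagged away.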

Differential Revision: D56941763
cccclai authored and facebook-github-bot committed May 3, 2024
1 parent b9488fe commit 03d39b7
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions backends/apple/mps/partition/mps_partitioner.py
@@ -43,12 +43,6 @@ def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
        self.edge_program = edge_program

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # Parameters are supported if any of their users are supported
        if is_parameter(self.edge_program, node):
            return any(
                self.is_node_supported(submodules, user) for user in node.users.keys()
            )

        if node.op != "call_function":
            return False

@@ -132,6 +126,7 @@ def partition(self, edge_program: ExportedProgram) -> PartitionResult:
        partitions = self.generate_partitions(edge_program=edge_program)
        if self.check_partitions(partitions):
            self.tag_nodes(partitions)
            # Tag constant data that are used by the supported ops in MPS backend.
            tag_constant_data(edge_program)
        x = PartitionResult(
            tagged_exported_program=edge_program, partition_tags=self.partition_tags