Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a registry for GraphModuleSerializer #126550

Closed
wants to merge 16 commits into from
Closed

Conversation

jiashenC
Copy link
Contributor

@jiashenC jiashenC commented May 17, 2024

This PR adds a registration function and a global registry for GraphModuleSerializer. After this PR, custom serialization methods can be done through registration instead of subclassing for ease of maintenance.

Changes

  • Add a test case where it injects custom op to test serialization.
  • Add custom op handler
  • Change allowed op for verifier

Copy link

pytorch-bot bot commented May 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126550

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3acfed8 with merge base 9117779 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jiashenC jiashenC force-pushed the hook_for_custom_serializer branch from fb2250a to 1e4b4da Compare May 21, 2024 20:46
@jiashenC jiashenC marked this pull request as ready for review May 21, 2024 20:47
torch/_export/serde/serialize.py Outdated Show resolved Hide resolved
torch/_export/serde/serialize.py Outdated Show resolved Hide resolved
torch/_export/serde/serialize.py Outdated Show resolved Hide resolved
torch/_export/serde/serialize.py Outdated Show resolved Hide resolved
@@ -187,7 +187,7 @@ def _allowed_op_types() -> Tuple[Type[Any], ...]:
)

if not isinstance(op, _allowed_op_types()):
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions and not isinstance(op, torch._export.serde.serialize.CustomOpHandler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the error you saw? I'm confused by the change here.

Suggested change
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions and not isinstance(op, torch._export.serde.serialize.CustomOpHandler):
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think without this change, there is no easy way to test the custom op handler. If I inject the custom op in the graph, it will throw error during verification.

torch._export.verifier.SpecViolationError: Operator '<export.test_serialize.TestSerialize.test_export_with_custom_op_serialization.<locals>.FooCustomOp object at 0x7f713ef391b0>' is not an allowed operator type: (<class 'torch._ops.OpOverload'>, <class 'torch._ops.HigherOrderOperator'>)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then sounds like we need to change _allowed_op_types()?

Comment on lines 2834 to 2836
if namespace not in _serialization_registry:
_serialization_registry[namespace] = {}
_serialization_registry[namespace][op_name] = serialization_fn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if namespace not in _serialization_registry:
_serialization_registry[namespace] = {}
_serialization_registry[namespace][op_name] = serialization_fn
_deserialization_registry[namespace] = op_handler

torch/_export/serde/serialize.py Outdated Show resolved Hide resolved
torch/_export/serde/serialize.py Outdated Show resolved Hide resolved
torch/_export/serde/serialize.py Outdated Show resolved Hide resolved
@@ -187,7 +187,7 @@ def _allowed_op_types() -> Tuple[Type[Any], ...]:
)

if not isinstance(op, _allowed_op_types()):
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions and not isinstance(op, torch._export.serde.serialize.CustomOpHandler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then sounds like we need to change _allowed_op_types()?

@jiashenC jiashenC requested a review from zhxchen17 May 23, 2024 23:28
Copy link
Contributor

@zhxchen17 zhxchen17 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks great!

Please address linter issues and other comments before landing this.

torch/_export/serde/serialize.py Outdated Show resolved Hide resolved
jiashenC and others added 15 commits May 24, 2024 15:19
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
@jiashenC jiashenC force-pushed the hook_for_custom_serializer branch from 25302d3 to b7689d4 Compare May 24, 2024 22:30
@facebook-github-bot
Copy link
Contributor

@jiashenC has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@jiashenC has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jiashenC jiashenC added the topic: not user facing topic category label May 28, 2024
@facebook-github-bot
Copy link
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Aidyn-A pushed a commit to tinglvv/pytorch that referenced this pull request May 30, 2024
This PR adds a registration function and a global registry for GraphModuleSerializer. After this PR, custom serialization methods can be done through registration instead of subclassing for ease of maintenance.

## Changes
- Add a test case where it injects custom op to test serialization.
- Add custom op handler
- Change allowed op for verifier
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Pull Request resolved: pytorch#126550
Approved by: https://github.com/zhxchen17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants