Add a registry for GraphModuleSerializer #126550
Conversation
torch/_export/verifier.py
Outdated
@@ -187,7 +187,7 @@ def _allowed_op_types() -> Tuple[Type[Any], ...]:
)

if not isinstance(op, _allowed_op_types()):
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions and not isinstance(op, torch._export.serde.serialize.CustomOpHandler):
what's the error you saw? I'm confused by the change here.
Suggested change:
- if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions and not isinstance(op, torch._export.serde.serialize.CustomOpHandler):
+ if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
I think without this change, there is no easy way to test the custom op handler. If I inject the custom op into the graph, it will throw an error during verification:
torch._export.verifier.SpecViolationError: Operator '<export.test_serialize.TestSerialize.test_export_with_custom_op_serialization.<locals>.FooCustomOp object at 0x7f713ef391b0>' is not an allowed operator type: (<class 'torch._ops.OpOverload'>, <class 'torch._ops.HigherOrderOperator'>)
then sounds like we need to change _allowed_op_types()?
torch/_export/serde/serialize.py
Outdated
if namespace not in _serialization_registry:
    _serialization_registry[namespace] = {}
_serialization_registry[namespace][op_name] = serialization_fn
Suggested change:
  if namespace not in _serialization_registry:
      _serialization_registry[namespace] = {}
  _serialization_registry[namespace][op_name] = serialization_fn
+ _deserialization_registry[namespace] = op_handler
looks great!
Please address linter issues and other comments before landing this.
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
@jiashenC has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since the Phabricator diff has merged, using force because this PR might not pass merge_rules.json but landed internally)
Merge started: your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR adds a registration function and a global registry for GraphModuleSerializer. After this PR, custom serialization methods can be added through registration instead of subclassing, for ease of maintenance.

Changes
- Add a test case that injects a custom op to test serialization.
- Add a custom op handler.
- Change the allowed ops for the verifier.

Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Pull Request resolved: pytorch#126550
Approved by: https://github.com/zhxchen17
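To illustrate why registration eases maintenance compared to subclassing, here is a hedged sketch of the two extension styles. All class and function names are hypothetical stand-ins, not the actual `torch._export` API.

```python
class GraphModuleSerializer:
    """Minimal stand-in for the real serializer."""

    def serialize_operator(self, op):
        return f"{op.namespace}::{op.name}"


# Before: each project subclasses and overrides, and must track upstream changes.
class MySerializer(GraphModuleSerializer):
    def serialize_operator(self, op):
        if op.namespace == "mylib":
            return f"custom/{op.name}"
        return super().serialize_operator(op)


# After: register a handler once; the serializer consults a global registry.
_registry = {}


def register_op_serializer(namespace, fn):
    _registry[namespace] = fn


class RegistryAwareSerializer(GraphModuleSerializer):
    def serialize_operator(self, op):
        fn = _registry.get(op.namespace)
        return fn(op) if fn else super().serialize_operator(op)


class Op:
    def __init__(self, namespace, name):
        self.namespace, self.name = namespace, name


register_op_serializer("mylib", lambda op: f"custom/{op.name}")
print(RegistryAwareSerializer().serialize_operator(Op("mylib", "foo")))  # custom/foo
```

The registration style also composes: several libraries can each register their own namespace without any of them needing to agree on a single serializer subclass.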