#8018: Adding rotary_embedding to ttnn #8616
Conversation
83c1147 to 2c1fce2
rotary_embedding(input_tensor: ttnn.Tensor, cos_cache: ttnn.Tensor, sin_cache: ttnn.Tensor, token_index: int, memory_config: MemoryConfig) -> ttnn.Tensor
pt_out = apply_rotary_pos_emb(x, cos_cached, sin_cached, token_idx)
return pt_out
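For context, here is a minimal PyTorch sketch of what a golden function such as `apply_rotary_pos_emb` might look like. This assumes GPT-NeoX-style "rotate-half" rotary embeddings; the function body and the `rotate_half` helper below are illustrative assumptions, not the actual code from this PR:

```python
import torch

def rotate_half(x):
    # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos_cached, sin_cached, token_idx=None):
    # Decode mode: pick the single cached row for the current token index.
    # Prefill mode (token_idx is None): use the first seq_len cached rows.
    seq_len = x.shape[-2]
    if token_idx is not None:
        cos = cos_cached[..., token_idx : token_idx + 1, :]
        sin = sin_cached[..., token_idx : token_idx + 1, :]
    else:
        cos = cos_cached[..., :seq_len, :]
        sin = sin_cached[..., :seq_len, :]
    # Standard RoPE formula: x * cos + rotate_half(x) * sin.
    return x * cos + rotate_half(x) * sin
```

With `cos = 1` and `sin = 0` the function is the identity, which is a quick sanity check that the sin/cos caches are being applied and not, say, transposed.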
@tt-aho, would you take a close look at the `_golden_function` here, written in PyTorch, which is intended to reproduce the rotary_embedding op we have?
This looks correct for the case where token_idx is provided.
a8bec8d to 2a69c4d