attend_dtype not used #531

Open
zhixuan-lin opened this issue Mar 18, 2024 · 1 comment

@zhixuan-lin

Here it seems that the hard-coded bfloat16 is used instead of attend_dtype; query is also not cast. I'd guess the correct behavior is to cast both query and self.embedding to attend_dtype?
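
A minimal sketch of the fix being suggested here, assuming the usual Flax-style Embed.attend shape (the names follow the issue, not the exact MaxText source):

    import jax.numpy as jnp

    # Hypothetical sketch, not the actual MaxText code: cast BOTH
    # operands to attend_dtype instead of hard-coding bfloat16 for
    # the embedding table only.
    def attend(query: jnp.ndarray,
               embedding: jnp.ndarray,
               attend_dtype: jnp.dtype) -> jnp.ndarray:
      # attend_dtype is jnp.float32 when cfg.logits_dot_in_fp32 is
      # set, so the logits dot product actually runs in fp32.
      return jnp.dot(query.astype(attend_dtype),
                     embedding.astype(attend_dtype).T)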

@rwitten
Collaborator

rwitten commented Mar 19, 2024

Yes, weird. @khatwanimohit can you take a look? I'm not sure what this is meant to represent, and the upstream flag is also odd given that it's orphaned:

attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability

I think we should figure out:
(a) does doing the dot in fp32 help convergence (using the 1B runs)? (see the sketch after this list)
(b) does @ZhiyuLi-goog/MLPerf care?
(c) what does Anselm Levskaya think?
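
A synthetic illustration of the concern in (a), not from the issue: bfloat16 keeps only ~8 mantissa bits, so a large logits-style dot product rounds away precision that float32 retains (the vectors and sizes below are arbitrary):

    import jax
    import jax.numpy as jnp

    # Compare the same dot product in float32 vs. bfloat16.
    q = jax.random.normal(jax.random.PRNGKey(0), (4096,), dtype=jnp.float32)
    e = jax.random.normal(jax.random.PRNGKey(1), (4096,), dtype=jnp.float32)

    fp32 = jnp.dot(q, e)
    bf16 = jnp.dot(q.astype(jnp.bfloat16), e.astype(jnp.bfloat16))

    # The bf16 result is visibly off from the fp32 reference.
    print(fp32, bf16.astype(jnp.float32))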

We should make the code consistent and as simple as possible. Also, why is our pylint/pytype not raising alarms on this? Unused vars are bad.
