
Question: Gradient Accumulation #607

Open
thiagolaitz opened this issue Apr 19, 2024 · 4 comments

@thiagolaitz

Hello, does this project support gradient accumulation or microbatches like those in the T5X repository? I didn't find a parameter for this in base.yml; maybe I just missed it? Thank you!

@rwitten
Collaborator

rwitten commented Apr 22, 2024

We don't support that out of the box. We've found that tuning LR to be smaller is a better approach.

What is your use case?

@thiagolaitz
Author

I'm training bigger models than before, so I can't use the same batch size on the same TPU. Got any recommended ablation studies on using gradient accumulation versus lowering the LR? Also, if I skip gradient accumulation, should I just linearly reduce the LR based on the batch size? Thanks!

@rodrigo-f-nogueira

+1
Adding another use case: since TPU availability varies, we sometimes train a model on a v4-128 TPU but later need to replicate the experiment on a v4-64, which has less memory. In that case we must use gradient accumulation to keep the results consistent.

@hxssgaa

hxssgaa commented May 10, 2024

Simply adding the following code after the optimizer is created in optimizers.py enables gradient accumulation (with a corresponding accumulate_gradient_steps entry added to the config):

if config.accumulate_gradient_steps > 1:
    optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps)
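For anyone who wants to try this outside MaxText first, here is a minimal, self-contained sketch of how optax.MultiSteps behaves: it averages gradients over every_k_schedule micro-batches and only then applies the wrapped optimizer's update, emitting zero updates in between. The accumulate_gradient_steps value and the toy loss below are just placeholders standing in for the config entry suggested above, not actual repository code.

import jax
import jax.numpy as jnp
import optax

# Hypothetical value standing in for config.accumulate_gradient_steps.
accumulate_gradient_steps = 4

# Wrap any base optimizer; MultiSteps averages gradients over
# `every_k_schedule` calls and applies the inner update on the k-th call.
base_optimizer = optax.adamw(learning_rate=1e-3)
optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=accumulate_gradient_steps)

params = {"w": jnp.ones((3,))}
opt_state = optimizer.init(params)

def loss_fn(params, batch):
    return jnp.sum((params["w"] * batch) ** 2)

@jax.jit
def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Updates are all zeros until the k-th micro-batch, when the
    # accumulated (mean) gradient is passed to the inner optimizer.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

for batch in [jnp.arange(3.0)] * (2 * accumulate_gradient_steps):
    params, opt_state = train_step(params, opt_state, batch)

If I understand MultiSteps correctly, the inner optimizer only sees one update per accumulate_gradient_steps micro-batches, so LR schedules keyed on optimizer steps count accumulated steps rather than micro-batches.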
