Multi-process batch size not calculated correctly #19726
Comments
Hi @natbprice, thanks for reporting. I have tested the code snippet and reproduced the reported behaviour. Attached gist for reference.
Thanks for the report and the investigation. After looking into it in detail, I came to the conclusion that this works as expected.

I saw your proposed fix:

```python
if isinstance(ds, _MapDataset) or isinstance(ds, _ParallelMapDataset):
    return ds._input_dataset._batch_size
```

But that's the batch size of the input dataset. The issue is that there is no constraint on what the function passed to `map` can do, so the batch size of the output dataset cannot be inferred from the batch size of the input dataset.

Now, why does this only happen when using multi-process distribution? That's because Keras is able to train with an unknown batch size in the normal case, and only tries to determine the batch size if distribution is turned on.

What's the fix? The standard pattern I've seen used is to batch last, after `map`:

```python
ds = tf.data.Dataset.from_tensor_slices((inputs, labels))
ds = ds.map(lambda x, y: (x, y))
ds = ds.batch(16)
```

Let me know if you have further questions.
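The behaviour described in this comment can be observed directly with the internal helper mentioned elsewhere in the thread. This is a minimal sketch, assuming a TensorFlow install; `compute_batch_size` is a private TensorFlow API and may change between versions:

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute

# A dataset batched directly: the helper can walk the dataset graph
# back to the BatchDataset node and recover the batch size of 16.
batched = tf.data.Dataset.from_tensor_slices(list(range(32))).batch(16)
static_size = int(distribute.compute_batch_size(batched).numpy())

# The same dataset after a map: the helper cannot prove the batch
# dimension survived the mapped function, so it reports -1 (unknown).
mapped = batched.map(lambda x: x)
mapped_size = int(distribute.compute_batch_size(mapped).numpy())
```

This is the same `-1` that causes the multi-process distribution path to fail, while ordinary training proceeds with the batch size simply left unknown.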
@hertschuh thanks for investigating this. Based on your conclusion, it sounds like this issue should instead be resolved in keras-team/keras-nlp#1630? In that case, a preprocessor is being mapped over the data internally, so there doesn't appear to be an easy workaround. Sorry if I created extra work; I guess I should not have opened a related issue here.
Yes, I think the fix should be in keras-nlp. One should simply apply the batching after the preprocessor has been mapped over the dataset.
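The preprocess-then-batch ordering being suggested might look like the following sketch. The `preprocessor` function here is a hypothetical stand-in for whatever keras-nlp maps over the data, not the actual keras-nlp API:

```python
import tensorflow as tf

# Hypothetical stand-in for a preprocessor that is mapped over
# individual (unbatched) examples.
def preprocessor(text, label):
    return tf.strings.lower(text), label

ds = tf.data.Dataset.from_tensor_slices(
    (["Foo", "Bar", "Baz", "Qux"], [0, 1, 0, 1]))
ds = ds.map(preprocessor)  # preprocess first, on unbatched elements
ds = ds.batch(2)           # batch last, so the batch size stays inferable
```

Because `batch` is the outermost transformation, the batch size can still be recovered from the dataset when distribution needs it.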
@hertschuh if you don't mind following up in keras-nlp, that would be great! I think I understand the solution you are proposing, but I can't quite figure out the best way for the keras-nlp API to function. In particular, it seems like there are several combinations to handle of (1) distribution strategy, (2) input types (e.g., tf.data.Dataset, NumPy arrays), and (3) batching (e.g., pre-batched dataset, explicit batching).
Describe the bug
I opened a related issue in keras-nlp, but I believe the issue is likely best addressed in keras. See related issue: keras-team/keras-nlp#1630
Currently, the batch size is not calculated correctly when performing multi-process distributed training with JAX backend if the dataset has been pre-processed with a mapping function.
To Reproduce
See https://colab.research.google.com/drive/1IxVNDcNoIK4SiX2wuDQKfqR_6Or9P40I?usp=sharing
Expected behavior
A batched `tf.data.Dataset` object is recognized as being batched.
Would you like to help us fix it?
I would like to try to fix this if it is not too complex. Maybe we can just replace the call to `tensorflow.python.data.experimental.ops.distribute.compute_batch_size()` with `dataset._input_dataset._batch_size`?