I can reproduce this bug.
It runs without error if we switch to `iter(train_state.batch_sampler.copy())` and `iter(self.batch_sampler.copy())`, but then `train_state.batch_sampler` loses track of the current `batch_count` and `num_consumed_samples_in_epoch`: with a copy, both values always stay at zero, which is not ideal when resuming training from a previous checkpoint.
The root cause is that `train_state.batch_sampler.batch_count` and `train_state.batch_sampler.num_consumed_samples_in_epoch` increment over iterations. Every time we re-create the generator with `train_state.batch_sampler_iter = iter(train_state.batch_sampler)`, it starts from the last saved `batch_count` and `num_consumed_samples_in_epoch`. We want the generator to restart iteration from `batch_count = 0`, not from a saved position that has already crossed the index boundary. That is why the bug occurs.
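To make the root cause concrete, here is a hypothetical minimal sampler (the class name `StatefulBatchSampler` and the batch size are illustrative; the counter names follow the issue). Its `__iter__` resumes from the persisted counters, so re-creating the iterator after an epoch yields nothing:

```python
# Hypothetical minimal sampler illustrating the root cause: __iter__
# resumes from persisted counters instead of restarting at zero.
class StatefulBatchSampler:
    def __init__(self, num_batches, batch_size=4):
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.batch_count = 0                    # persists across iter() calls
        self.num_consumed_samples_in_epoch = 0  # persists across iter() calls

    def __iter__(self):
        # Resuming from the saved position is useful for checkpoint resume,
        # but wrong when the caller expects a fresh epoch.
        while self.batch_count < self.num_batches:
            self.batch_count += 1
            self.num_consumed_samples_in_epoch += self.batch_size
            yield list(range(self.batch_size))

sampler = StatefulBatchSampler(num_batches=3)
first_pass = sum(1 for _ in iter(sampler))   # consumes 3 batches
second_pass = sum(1 for _ in iter(sampler))  # batch_count is already 3 -> 0 batches
print(first_pass, second_pass)  # 3 0
```

The second `iter()` call returns a generator that is exhausted immediately, which is exactly why reassigning `train_state.batch_sampler_iter` does not help.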
So the solution is clear: reset these two values in the `StopIteration` exception handler.
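A minimal sketch of that fix, under the same assumptions as above (the `StatefulBatchSampler` class is hypothetical; the real training loop differs): reset the persisted counters inside the `StopIteration` handler before re-creating the iterator.

```python
# Hypothetical sampler, as in the reproduction above.
class StatefulBatchSampler:
    def __init__(self, num_batches, batch_size=4):
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.batch_count = 0
        self.num_consumed_samples_in_epoch = 0

    def __iter__(self):
        while self.batch_count < self.num_batches:
            self.batch_count += 1
            self.num_consumed_samples_in_epoch += self.batch_size
            yield list(range(self.batch_size))

sampler = StatefulBatchSampler(num_batches=3)
batch_sampler_iter = iter(sampler)

batches_seen = 0
for _ in range(7):  # more steps than one epoch provides
    try:
        next(batch_sampler_iter)
    except StopIteration:
        # The fix: reset the counters so the fresh iterator really starts
        # a new epoch instead of resuming past the end.
        sampler.batch_count = 0
        sampler.num_consumed_samples_in_epoch = 0
        batch_sampler_iter = iter(sampler)
        next(batch_sampler_iter)
    batches_seen += 1
print(batches_seen)  # 7 -- training continues across epoch boundaries
```

Without the two reset lines, the `next()` inside the handler would raise `StopIteration` again on the fourth step.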
We fixed this bug in #102.
Describe the bug
When my data is not enough to run through the full `total_step`, a `StopIteration` error occurs. Although a try-except is used around it, `next(train_state.batch_sampler_iter)` at line 512 still goes out of bounds, because `train_state.batch_sampler` also advances as iteration proceeds; so even after `train_state.batch_sampler_iter` is reassigned, it still goes out of bounds.
Environment information
Other information
After making the following modification, it runs through.