Print Time More Accurately In MaxText #632

Closed · wants to merge 2 commits

Changes from all commits
30 changes: 22 additions & 8 deletions MaxText/train.py
@@ -101,9 +101,11 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):

_buffered_step = None
_buffered_metrics = None
_last_buffer_time = None
_buffered_lr = None


def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
def write_metrics(writer, per_device_tflops, lr, local_metrics_file, running_gcs_metrics, metrics, step, config):
"""Entry point for all metrics writing in Train's Main.
TODO: would be better as a Class in the future (that initializes all state!)
@@ -112,11 +114,23 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
The logic is that this ensures that Jax is able to queue train_steps and we
don't block when turning "lazy" Jax arrays into real Python numbers.
"""
global _buffered_step, _buffered_metrics
global _buffered_step, _buffered_metrics, _last_buffer_time, _buffered_lr


if _buffered_metrics is not None:
jax.block_until_ready(_buffered_metrics)

next_buffer_time = datetime.datetime.now()

if _buffered_metrics is not None:
if _buffered_step is None:
raise ValueError(f"When writing metrics, {_buffered_step=} was none")
if _last_buffer_time is None:
raise ValueError(f"When writing metrics, {_last_buffer_time=} was none")
if _buffered_lr is None:
raise ValueError(f"When writing metrics, {_buffered_lr=} was none")

record_scalar_metrics(_buffered_metrics, (next_buffer_time - _last_buffer_time), per_device_tflops, _buffered_lr)
write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)

if config.metrics_file:
@@ -127,6 +141,8 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step

_buffered_step = step
_buffered_metrics = metrics
_last_buffer_time = next_buffer_time
_buffered_lr = lr


def write_metrics_to_tensorboard(writer, metrics, step, config):
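
For orientation, here is a minimal, standalone sketch of the double-buffered metrics path that the revised write_metrics implements: buffer the current step's metrics, and only time and publish the previous step's metrics once jax.block_until_ready confirms its arrays have materialized. The trimmed-down signature and the helper name publish_fn are illustrative, not part of MaxText.

import datetime
import jax

_buffered_metrics = None
_last_buffer_time = None

def write_metrics_sketch(metrics, publish_fn):
    # Publish the *previous* step's metrics, measuring its duration only after
    # the device has actually finished producing them.
    global _buffered_metrics, _last_buffer_time

    if _buffered_metrics is not None:
        jax.block_until_ready(_buffered_metrics)  # wait for the buffered step to finish

    now = datetime.datetime.now()
    if _buffered_metrics is not None:
        step_time = now - _last_buffer_time       # spans device compute, not just dispatch
        publish_fn(_buffered_metrics, step_time)

    _buffered_metrics = metrics                   # hold the current step until the next call
    _last_buffer_time = now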
@@ -486,7 +502,6 @@ def train_loop(config, state=None):
last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1)

example_batch = None
last_step_completion = datetime.datetime.now()

for step in np.arange(start_step, config.steps):
if step == first_profiling_step:
@@ -500,9 +515,6 @@ def train_loop(config, state=None):
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, nextrng)

new_time = datetime.datetime.now()
record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step))
last_step_completion = new_time

if checkpoint_manager is not None:
if save_checkpoint(checkpoint_manager, step, state, config.dataset_type, data_iterator):
@@ -513,7 +525,8 @@ def train_loop(config, state=None):
checkpoint_manager.wait_until_finished()
sys.exit()

write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config)
write_metrics(writer, per_device_tflops, learning_rate_schedule(step), local_metrics_file, running_gcs_metrics, metrics,
step, config)

if config.eval_interval > 0 and step > start_step and step % config.eval_interval == 0:
assert eval_data_iterator
@@ -535,7 +548,8 @@

if checkpoint_manager is not None:
checkpoint_manager.wait_until_finished()
write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics
write_metrics(writer, per_device_tflops, learning_rate_schedule(config.steps - 1), local_metrics_file,
running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics
max_utils.close_summary_writer(writer)
record_goodput(recorder, config, job_end=True)
return state
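
The reason this change moves the timestamp behind jax.block_until_ready is that JAX dispatches work asynchronously, so timing a train step immediately after the call mostly measures dispatch overhead rather than device compute. A small self-contained illustration of that effect (fake_train_step and the matmul are placeholders, not p_train_step):

import datetime
import jax
import jax.numpy as jnp

@jax.jit
def fake_train_step(x):
    return x @ x  # stand-in for a real train step

x = jnp.ones((4096, 4096))
start = datetime.datetime.now()
out = fake_train_step(x)                       # returns almost immediately (async dispatch)
naive = datetime.datetime.now() - start
jax.block_until_ready(out)                     # wait for the device to finish the work
accurate = datetime.datetime.now() - start
print(f"naive={naive}, accurate={accurate}")   # naive is typically far smaller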