When converting a checkpoint to a parameter-only checkpoint using the maxtext/MaxText/generate_param_only_checkpoint.py script on a CPU device, I get this error:
File "/home/hosseins_google_com/miniconda3/envs/jetstream/lib/python3.10/site-packages/jax/_src/distributed.py", line 62, in initialize
raise ValueError('coordinator_address should be defined.')
ValueError: coordinator_address should be defined.
Looking closely at max_utils.maybe_initialize_jax_distributed_system(), I noticed that the GPU condition is met, so it calls jax.distributed.initialize().
Changing the conditions in max_utils.maybe_initialize_jax_distributed_system() to force a call to initialize_jax_for_cpu results in the following error:
UnboundLocalError: local variable 'coordinator_ip_address' referenced before assignment
It seems that max_utils.get_coordinator_ip_address() doesn't handle the case where JAX_COORDINATOR_ADDRESS is None, and/or it's unclear how it is meant to work on a CPU device.
hosseinsarshar changed the title from "JAX_COORDINATOR_ADDRESS is None" to "Fail when JAX_COORDINATOR_ADDRESS is None" (Apr 30, 2024)
hosseinsarshar changed the title from "Fail when JAX_COORDINATOR_ADDRESS is None" to "Conversion fails when JAX_COORDINATOR_ADDRESS is None" (Apr 30, 2024)
Some more context: in this line, the return coordinator_ip_address statement can throw an error if the JAX_COORDINATOR_ADDRESS environment variable is None, because coordinator_ip_address is then undefined. I think the author meant to initialize it here but initialized coordinator_address instead. I'm not sure whether that's intended, but I wanted to raise an issue to let you know it exists.
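To illustrate the failure mode described above, here is a minimal sketch. It is not the actual max_utils code: the function names get_ip_buggy and get_ip_fixed, the env_var parameter, and the function bodies are assumptions made for illustration. It only shows how returning a name that is assigned solely inside an if branch raises UnboundLocalError when the environment variable is unset, and how initializing that name up front avoids it.

```python
import os
import socket


def get_ip_buggy(env_var="JAX_COORDINATOR_ADDRESS"):
    # Hypothetical reproduction of the reported pattern: the wrong name is
    # initialized, so the returned name is only bound inside the `if`.
    coordinator_address = ""
    if os.environ.get(env_var) is not None:
        coordinator_address = os.environ[env_var]
        coordinator_ip_address = socket.gethostbyname(coordinator_address)
    # Raises UnboundLocalError when env_var is not set.
    return coordinator_ip_address


def get_ip_fixed(env_var="JAX_COORDINATOR_ADDRESS"):
    # Initialize the name that is actually returned, so the unset case
    # yields an empty string instead of an exception.
    coordinator_ip_address = ""
    coordinator_address = os.environ.get(env_var)
    if coordinator_address is not None:
        coordinator_ip_address = socket.gethostbyname(coordinator_address)
    return coordinator_ip_address
```

Whether an empty string is the right fallback for a CPU-only run is a separate question; the sketch only removes the unbound-variable error.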
hosseinsarshar changed the title from "Conversion fails when JAX_COORDINATOR_ADDRESS is None" to "How to convert a model to parameter only checkpoints (unscanned) on a CPU VM" (May 9, 2024)
Here is a quick update: the issue was the base image of my CPU VM. I changed it from the default image to the DLVM image (c0-deeplearning-common-cpu-v20240417-debian-11) and it worked. You also need to pass the async_checkpointing=false parameter to generate_param_only_checkpoint.py to bypass jax.distributed; without it, the job fails on the CPU VM regardless of the base image.
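For anyone scripting the workaround, a rough sketch of assembling the command line follows. The helper name build_conversion_command, the positional config-file argument, and the extra_overrides parameter are assumptions for illustration; only the script path and the async_checkpointing=false override come from this thread. The sketch builds the argument list and does not run anything.

```python
import sys


def build_conversion_command(
    config_path,
    extra_overrides=(),
    script="MaxText/generate_param_only_checkpoint.py",
):
    # Assumption: the script takes a config file followed by key=value
    # overrides. Adjust to match how you already invoke it.
    return [
        sys.executable,
        script,
        config_path,
        # From the thread: needed on a CPU VM to bypass jax.distributed.
        "async_checkpointing=false",
        *extra_overrides,
    ]
```

The resulting list can be passed to subprocess.run(cmd, check=True) together with whatever checkpoint and model overrides your run already uses.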