How to convert a model to parameter only checkpoints (unscanned) on a CPU VM #634

Closed
hosseinsarshar opened this issue Apr 30, 2024 · 2 comments

Comments


hosseinsarshar commented Apr 30, 2024

When converting a checkpoint to a parameter-only checkpoint using the maxtext/MaxText/generate_param_only_checkpoint.py script on a CPU device, I get this error:

  File "/home/hosseins_google_com/miniconda3/envs/jetstream/lib/python3.10/site-packages/jax/_src/distributed.py", line 62, in initialize
    raise ValueError('coordinator_address should be defined.')
ValueError: coordinator_address should be defined.

Looking closely at max_utils.maybe_initialize_jax_distributed_system(), I noticed that the GPU condition is met and it calls jax.distributed.initialize().

Changing the conditions in max_utils.maybe_initialize_jax_distributed_system() to force calling initialize_jax_for_cpu results in the following error:

UnboundLocalError: local variable 'coordinator_ip_address' referenced before assignment

It seems that max_utils.get_coordinator_ip_address() doesn't handle the case where JAX_COORDINATOR_ADDRESS is None, and/or it's unclear how it is supposed to work on a CPU device.

@hosseinsarshar hosseinsarshar changed the title JAX_COORDINATOR_ADDRESS is None Fail when JAX_COORDINATOR_ADDRESS is None Apr 30, 2024
@hosseinsarshar hosseinsarshar changed the title Fail when JAX_COORDINATOR_ADDRESS is None Conversion fails when JAX_COORDINATOR_ADDRESS is None Apr 30, 2024
@hosseinsarshar
Author

Some more context: in this line, the return coordinator_ip_address can throw an error if the JAX_COORDINATOR_ADDRESS environment variable is None, because coordinator_ip_address is then undefined. I think the author wanted to initialize it here but initialized coordinator_address instead. Not sure if it's intended, but I wanted to raise an issue to let you know that it exists.
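The failure mode described above can be reproduced in isolation. What follows is a minimal sketch, not the MaxText source: the function names `get_coordinator_ip_buggy` and `get_coordinator_ip_fixed` are hypothetical, and whatever else the real function does with the address is omitted. It only shows how assigning one name and returning another leads to an UnboundLocalError when the environment variable is unset.

```python
import os

ENV_VAR = "JAX_COORDINATOR_ADDRESS"


def get_coordinator_ip_buggy():
    """Hypothetical reduction of the bug: the wrong name is initialized."""
    coordinator_address = os.environ.get(ENV_VAR)  # None when the variable is unset
    if coordinator_address:
        coordinator_ip_address = coordinator_address
    # If the branch above was skipped, this name was never assigned.
    return coordinator_ip_address


def get_coordinator_ip_fixed():
    """Same logic, but the returned name always has a value."""
    coordinator_ip_address = None
    coordinator_address = os.environ.get(ENV_VAR)
    if coordinator_address:
        coordinator_ip_address = coordinator_address
    return coordinator_ip_address
```

With the variable unset, the first function raises UnboundLocalError and the second returns None, which lets the caller decide how to proceed on a single-host CPU machine.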

@hosseinsarshar hosseinsarshar changed the title Conversion fails when JAX_COORDINATOR_ADDRESS is None How to convert a model to parameter only checkpoints (unscanned) on a CPU VM May 9, 2024
@hosseinsarshar
Author

Here is a quick update: the issue was the base image of my CPU VM. I changed it from the default image to the DLVM image (c0-deeplearning-common-cpu-v20240417-debian-11) and it worked. You also need to pass the async_checkpointing=false parameter to generate_param_only_checkpoint.py to bypass jax.distributed; the job fails without it on the CPU VM regardless of the base image.
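For anyone scripting this, the workaround amounts to appending `async_checkpointing=false` to the script's arguments. A rough sketch, assuming the script takes a config file followed by `key=value` overrides; the helper `build_command` and the config path are placeholders I made up for illustration, and any other overrides your run needs are left out.

```python
import shlex

# Script path relative to the maxtext checkout (assumed working directory).
SCRIPT = "MaxText/generate_param_only_checkpoint.py"


def build_command(config_path, overrides):
    """Return an argv list: interpreter, script, config file, key=value overrides."""
    argv = ["python3", SCRIPT, config_path]
    argv += [f"{key}={value}" for key, value in overrides.items()]
    return argv


# Placeholder config path; other required overrides are omitted here.
cmd = build_command("MaxText/configs/base.yml", {"async_checkpointing": "false"})
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root.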
