add HF input pipeline #592
Conversation
This CR doesn't feel done -- switching to HF seems great, but loading on a single process isn't an option, and MaxText's data loading is feeling increasingly bloated and unattractive. Shouldn't we be able to use the same tokenizers we always do and just use HF for loading the raw data?
However, not merging this also doesn't feel like a good option. You'll face rebases, etc.
I think we should decouple this development from MaxText. You should work out HF datasets to the point where you have a polished implementation that lets us remove TFDS etc. At that point, we merge a CR that disables TFDS and enables HF.
@@ -155,19 +155,30 @@ ici_autoregressive_parallelism: 1
# you should set compile_toplogy_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
num_slices: -1

# Tokenizer
vocab_size: 32_000 # powers of 2 for sharding
# When using HF pipeline, set tokenizer_path to a tokenizer in HF hub, e.g. "google-t5/t5-large", or a local folder contains HF tokenizer
The HF pipeline isn't compatible with all of our tokenizers? Why not?
The current tokenizers are in SentencePiece model format, which is not supported by the HF tokenizer loader, AutoTokenizer; HF expects tokenizers in JSON format.
I've added more details in this doc (go/maxtext-input-pipeline), under "Issues during implementation and solutions - tokenizer".
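To illustrate the distinction: a SentencePiece tokenizer ships as a single `.model` protobuf file, while AutoTokenizer expects either a hub repo id or a local folder containing a JSON-format `tokenizer.json`. The sketch below uses a hypothetical `detect_tokenizer_format` helper (not part of MaxText) to show the two layouts.

```python
import os

def detect_tokenizer_format(path: str) -> str:
    """Guess whether a local tokenizer path is a SentencePiece model or an
    HF-style tokenizer folder. Hypothetical helper for illustration only."""
    if path.endswith(".model"):
        # Single protobuf file, e.g. assets/tokenizer.model -- not loadable
        # by AutoTokenizer.
        return "sentencepiece"
    if os.path.isdir(path) and os.path.exists(os.path.join(path, "tokenizer.json")):
        # JSON-format tokenizer folder, loadable by AutoTokenizer.
        return "huggingface"
    # Otherwise it may be a hub repo id such as "google-t5/t5-large".
    return "unknown"
```

For example, `detect_tokenizer_format("assets/tokenizer.model")` returns `"sentencepiece"`, while a folder holding `tokenizer.json` returns `"huggingface"`.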
- name: Test train.py with HF c4
  run: |
    docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
    'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet dataset_name=parquet dataset_type=hf steps=2 tokenizer_path=google-t5/t5-large enable_checkpointing=false'
I don't understand this command. Is it loading a single parquet file or all parquet files in that directory?
Yes, this uses only a single file. Using the whole dataset adds the overhead of resolving every file in it, which is too much for unit tests.
I've added more details in this doc (go/maxtext-input-pipeline), under "Issues during implementation and solutions - HF hub or local dataset?".
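As a sketch of the trade-off: resolving a full hub dataset enumerates every shard, while a unit test can point at one deterministic file. The `pick_ci_shard` helper below is hypothetical, showing how a CI job might select a single shard from a known file list.

```python
def pick_ci_shard(shard_files, index=0):
    """Pick one deterministic shard for a fast unit test instead of
    resolving the whole dataset (hypothetical helper)."""
    # Sort so the choice is stable regardless of listing order.
    return sorted(shard_files)[index]

shards = [
    "c4-train-00001-of-01637.parquet",
    "c4-train-00000-of-01637.parquet",
]
print(pick_ci_shard(shards))  # → c4-train-00000-of-01637.parquet
```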
# Dataset
# Replace with your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
# For TFDS pipeline, Replace with your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
I'm very concerned by the level of complexity here -- it feels very un-MaxTexty and needlessly confusing for our reference implementation. It is critical that we take action to make this simple.
  batch_size % global_mesh.size == 0
), 'Batch size should be divisible number of global devices.'

dataset = split_dataset_by_node(dataset, world_size=jax.process_count(), rank=jax.process_index())
@khatwanimohit make sure to sync with Mohit here, this is no longer right.
This implementation does not support "expansion_factor_real_data" yet, as I noted in the doc input_pipeline.md; currently "expansion_factor_real_data" is ignored when dataset_type=hf. I will add support for it in a later PR. This should work for now. @khatwanimohit please also take a look.
I don't see a reason why it can't be done in this PR. It's a small change. You'll just have to pass
dataloading_host_index = process_indices.index(jax.process_index()),
dataloading_host_count = len(process_indices)
to _hf_data_processing.preprocess_dataset.
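The suggested change is plain index arithmetic: when only a subset of processes loads real data (the expansion_factor_real_data case), the dataset should be split across that subset rather than across all `jax.process_count()` hosts. A hypothetical helper sketching the two values being passed:

```python
def host_shard_params(process_indices, my_process_index):
    """Return (dataloading_host_index, dataloading_host_count) for this host.

    process_indices: the subset of JAX process indices that load real data.
    Hypothetical sketch of the values suggested in the comment above.
    """
    return process_indices.index(my_process_index), len(process_indices)

# e.g. 8 total processes, but only 4 of them load data:
idx, count = host_shard_params([0, 2, 4, 6], my_process_index=4)
print(idx, count)  # → 2 4
```

With these two values, the dataset is split into `dataloading_host_count` shards instead of `jax.process_count()` shards, so non-loading hosts are skipped.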
if pack_examples:
  pack_op = _hf_operations.PackAndBatchOperation(
      batch_size=batch_size // jax.process_count(),
Probably this is also no longer right -- talk to Mohit!
@@ -37,11 +37,17 @@ then
  CMD_DATA=" dataset_type=c4-array_record dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1"
fi

if [ "$DATASET_TYPE" == "hf" ]
then
  gsutil cp -r gs://maxtext-dataset/hf/llama2-tokenizer assets
Can we communicate via M_ flags here?
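For context: MaxText's config loader treats environment variables prefixed with `M_` as config overrides, stripping the prefix and lowercasing the key. A minimal sketch of that mapping, assuming this convention (the helper name is hypothetical):

```python
def collect_m_overrides(environ):
    """Map M_-prefixed environment variables to config keys, mirroring the
    convention of stripping the M_ prefix and lowercasing the remainder.
    Hypothetical sketch, not MaxText's actual loader."""
    return {
        key[len("M_"):].lower(): value
        for key, value in environ.items()
        if key.startswith("M_")
    }

env = {"M_TOKENIZER_PATH": "assets/llama2-tokenizer", "PATH": "/usr/bin"}
print(collect_m_overrides(env))  # → {'tokenizer_path': 'assets/llama2-tokenizer'}
```

So instead of hard-coding `gsutil cp` paths in the test script, the tokenizer location could be injected as `M_TOKENIZER_PATH` from the CI environment.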
1. Streaming data directly from HuggingFace Hub may be affected by server traffic. During peak hours you may encounter "504 Server Error: Gateway Time-out".
2. The current implementation uses only a single process to retrieve data. This may slow down training if you train a small model on a big slice with a large batch (e.g. a 1B model on a v4-128 with a batch of 1M tokens). We are working on improving performance.
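Regarding the 504 timeouts in point 1: a common mitigation is to wrap the streaming fetch in retries with exponential backoff. A minimal sketch, where `fetch` is a hypothetical stand-in for the hub/streaming request:

```python
import time

def with_retries(fetch, max_attempts=5, base_delay=1.0):
    """Call fetch(), retrying transient errors with exponential backoff.
    `fetch` is a hypothetical stand-in for a hub/streaming request."""
    for attempt in range(max_attempts):
        try:
            return fetch()
        except ConnectionError:  # e.g. "504 Server Error: Gateway Time-out"
            if attempt == max_attempts - 1:
                raise  # out of attempts, surface the error
            time.sleep(base_delay * 2 ** attempt)

# Example: a fake fetch that fails twice, then succeeds.
calls = {"n": 0}
def flaky_fetch():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("504 Server Error: Gateway Time-out")
    return "batch"

print(with_retries(flaky_fetch, base_delay=0.0))  # → batch
```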
This is unacceptable to me.
No description provided.