feat(data): Default grain_worker_count to -1 for automatic performance tuning
#2510
## Description

This PR changes the default value of `grain_worker_count` to `-1`. This enables Grain's experimental `pick_performance_config` feature to automatically determine the optimal number of data-loading workers, significantly improving out-of-the-box training performance and preventing data pipeline bottlenecks.

Previously, `grain_worker_count` had a static default that was often suboptimal, requiring users to manually tune this hyperparameter through trial and error to achieve good hardware utilization.

### The Problem

When tokenizing raw text data on the fly, the data input pipeline can easily become a bottleneck if the number of parallel workers (`grain_worker_count`) is too low. This leads to poor accelerator utilization (low TFLOP/s), slow training steps, and a frustrating user experience.

### The Solution

By setting the default to `-1`, we delegate the selection of the worker count to Grain's built-in auto-tuning mechanism. This provides a robust default that adapts to the user's specific hardware and data, ensuring the input pipeline can keep up with the accelerators. As shown in the performance tests below, this automatic configuration achieves stable, high-throughput training comparable to a manually optimized `grain_worker_count` setting.

This change simplifies the user workflow and makes it easier to achieve optimal performance without manual intervention.
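For context, the `-1` sentinel pattern can be sketched as below. This is a hypothetical illustration only: the function name `resolve_worker_count` and the CPU-count heuristic are assumptions standing in for Grain's actual `pick_performance_config` auto-tuning, not MaxText or Grain internals.

```python
import multiprocessing


def resolve_worker_count(grain_worker_count: int) -> int:
    """Resolve a configured worker count.

    -1 means "auto": delegate to an auto-tuning heuristic (approximated
    here by the host CPU count as a placeholder for Grain's experimental
    pick_performance_config). Any non-negative value is used as-is.
    """
    if grain_worker_count == -1:
        # Placeholder heuristic; the real auto-tuner also considers the
        # data and hardware profile, not just CPU count.
        return multiprocessing.cpu_count()
    if grain_worker_count < -1:
        raise ValueError("grain_worker_count must be >= -1")
    return grain_worker_count
```

With this shape, an explicit value (e.g. `4`) behaves exactly as before, so existing tuned configurations are unaffected; only the default path changes.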
## Tests

The effectiveness of this change was verified by running the training command below on a v6e-32 pod with different values for `grain_worker_count` and observing the impact on TFLOP/s and step time.

To reproduce, run the command with `grain_worker_count` set to `1`, `4`, `8`, and `-1` (the new default):

```shell
python3 -m MaxText.train src/MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  run_name=$RUN_NAME \
  dataset_type=grain \
  grain_train_files=${DATA_PATH} \
  grain_worker_count=-1 \
  per_device_batch_size=2 \
  model_name=llama3-8b \
  steps=10 \
  max_target_length=8192 \
  enable_checkpointing=false \
  attention=flash \
  dtype=bfloat16
```

Confirm that `grain_worker_count=-1` results in stable and high TFLOP/s (~195 on v6e-32) and low step times (~4.3 s), consistent with the performance of a manually tuned optimal value such as 4 or 8.

## Checklist
Before submitting this PR, please make sure (put X in square brackets):

- [ ] `gemini-review` label.

Fixes #2509