@bzantium bzantium commented Oct 16, 2025

Description

This PR changes the default value of grain_worker_count to -1. This enables Grain's experimental pick_performance_config feature to automatically determine the optimal number of data loading workers, significantly improving out-of-the-box training performance and preventing data pipeline bottlenecks.

Previously, grain_worker_count had a static default that was often suboptimal, requiring users to manually tune this hyperparameter through trial and error to achieve good hardware utilization.

The Problem

When tokenizing raw text data on the fly, the data input pipeline can easily become a bottleneck if the number of parallel workers (grain_worker_count) is too low. This leads to poor accelerator utilization (low TFLOP/s), slow training steps, and a frustrating user experience.

The Solution

By setting the default to -1, we delegate the selection of the worker count to Grain's built-in auto-tuning mechanism. This provides a robust default that adapts to the user's specific hardware and data, ensuring the input pipeline can keep up with the accelerators.
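The sentinel logic can be sketched as follows. Note this is an illustrative sketch, not the actual MaxText/Grain implementation: `resolve_worker_count` is a hypothetical helper, and the CPU-count fallback merely stands in for Grain's experimental `pick_performance_config` auto-tuner.

```python
import os

def resolve_worker_count(grain_worker_count: int) -> int:
    """Return an explicit worker count, treating -1 as 'auto'.

    Hypothetical helper: the real auto-tuning is delegated to Grain's
    experimental pick_performance_config; a host CPU count is used here
    only as a rough stand-in heuristic.
    """
    if grain_worker_count >= 0:
        # An explicit user-provided value always wins.
        return grain_worker_count
    # Auto mode: pick a count based on the host instead of a static default.
    return os.cpu_count() or 1

print(resolve_worker_count(8))   # explicit value is passed through
print(resolve_worker_count(-1))  # resolved automatically from the host
```

The key design point is that `-1` is a sentinel, not a literal worker count: any non-negative value still behaves exactly as before.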

As shown in the performance tests below, this automatic configuration achieves stable, high-throughput training comparable to a manually optimized setting.

| grain_worker_count | Average TFLOP/s/device | Average Time/Step (s) | Stability |
|---|---|---|---|
| 1 | ~29 | ~30.6 | Unstable |
| 2 | ~60 | ~13.5 | Highly unstable |
| 4 | ~195 | ~4.3 | Weakly unstable |
| 8 | ~195 | ~4.3 | Stable |
| -1 (auto) | ~195 | ~4.3 | Stable |

This change simplifies the user workflow and makes it easier to achieve optimal performance without manual intervention.

Tests

The effectiveness of this change was verified by running the training command below on a v6e-32 pod with different values for grain_worker_count and observing the impact on TFLOP/s and step time.

To reproduce, run the command with grain_worker_count set to 1, 4, 8, and -1 (the new default).

```shell
python3 -m MaxText.train src/MaxText/configs/base.yml \
    base_output_directory=${BASE_OUTPUT_DIRECTORY} \
    run_name=$RUN_NAME \
    dataset_type=grain \
    grain_train_files=${DATA_PATH} \
    grain_worker_count=-1 \
    per_device_batch_size=2 \
    model_name=llama3-8b \
    steps=10 \
    max_target_length=8192 \
    enable_checkpointing=false \
    attention=flash \
    dtype=bfloat16
```

Confirm that grain_worker_count=-1 results in stable and high TFLOP/s (~195 on v6e-32) and low step times (~4.3s), consistent with the performance of a manually tuned optimal value like 4 or 8.
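The sweep can be scripted with a small loop. This is a minimal sketch: `TRAIN_CMD` is a hypothetical stand-in for the full `python3 -m MaxText.train ...` invocation above, and the launch line is left commented out so the script only prints what it would run.

```shell
# Hypothetical sweep helper over the tested worker counts.
for workers in 1 4 8 -1; do
  run="${RUN_NAME:-sweep}-w${workers}"
  echo "launching ${run} with grain_worker_count=${workers}"
  # ${TRAIN_CMD} grain_worker_count="${workers}" run_name="${run}"
done
```

Giving each run a distinct `run_name` keeps the per-setting metrics separate in the output directory.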

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

fixes #2509

@bzantium

This requires grain>=0.2.13, which can be resolved by #2354.

@aireenmei aireenmei self-assigned this Oct 22, 2025

Successfully merging this pull request may close these issues.

Automatically Optimize grain_worker_count for Improved Data Loading Performance
