
Conversation

@SamuelMarks (Collaborator) commented Jun 15, 2025

Description

Background

MaxText is currently engineered around pyconfig. pyconfig (https://pypi.org/project/pyconfig/) was last updated in 2017 and has 20k monthly downloads.

Pydantic (https://pypi.org/project/pydantic/) is actively maintained and has hundreds of millions of monthly downloads.

Pydantic is the most widely used data validation library for Python.

https://docs.pydantic.dev

Summary

The TL;DR version is:

  • pydantic has become the de facto standard for configuration formats and for specifying inputs and outputs (e.g., request/response models in the FastAPI framework)
  • pydantic integrates well with Python type-checkers, so many errors can be found before runtime
  • pydantic's validation errors are clean and clear
  • with my Python compiler, you can get clean CLIs with --help, GUIs, and generated SQL models (SQLAlchemy) as desired
  • SDK documentation becomes much cleaner (good preparation for the PyPI release and hosted doc pages)
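
To make these points concrete, here is a minimal sketch (not the actual MaxText types; field names are borrowed from the training commands in the Tests section below) of what a pydantic config model and its error output could look like:

from pydantic import BaseModel, Field, ValidationError

class TrainConfig(BaseModel):
    """Illustrative subset of a MaxText run configuration."""
    run_name: str
    model_name: str = "default"
    dataset_type: str = "synthetic"
    steps: int = Field(default=10, gt=0)
    per_device_batch_size: float = Field(default=1.0, gt=0)

try:
    # A type-checker flags this before runtime; pydantic catches it at runtime.
    TrainConfig(run_name="demo", steps="ten")
except ValidationError as err:
    print(err)
    # 1 validation error for TrainConfig
    # steps
    #   Input should be a valid integer, unable to parse string as an integer ...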

Migration

Proposed changes to the MaxText codebase:

  1. Completely remove the pyconfig dependency
    1. maybe not immediately, so people have a chance to migrate their existing setups easily, e.g., with new functions such as from_pyconfig_to_pydantic (see the sketch after this list)
  2. Migrate all examples and codebase occurrences to the new pydantic types
  3. Organize the pydantic types into a hierarchy that both machines and humans can navigate, e.g.:
    1. One global types.py, or
    2. One types.py per config occurrence (e.g., one per module if each module has a different config)
  4. Create a new CLI that uses common CLI syntax (e.g., this can be generated automatically using my Python compiler https://github.com/offscale/cdd-python)
  5. Migrate all shell scripts and docs to use this new CLI
    1. TBD: remove the shell scripts in favour of Python SDK usage.
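
A minimal sketch of the from_pyconfig_to_pydantic bridge from step 1.1, assuming the existing pyconfig-style setup can first be flattened into a plain dict (TrainConfig here is the illustrative model from the Summary sketch, not the real MaxText type):

from typing import Any, Mapping
from pydantic import BaseModel

class TrainConfig(BaseModel):
    run_name: str
    model_name: str = "default"
    steps: int = 10

def from_pyconfig_to_pydantic(raw: Mapping[str, Any]) -> TrainConfig:
    """Validate a flat pyconfig-style mapping into the pydantic model,
    dropping keys the model does not declare rather than failing on them."""
    known = {k: v for k, v in raw.items() if k in TrainConfig.model_fields}
    return TrainConfig(**known)

# Unknown legacy keys are dropped; declared keys are validated and coerced.
cfg = from_pyconfig_to_pydantic({"run_name": "demo", "steps": "10", "legacy_flag": True})
print(cfg.steps)  # 10 (coerced from the string "10")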

Tests

CI and manual:

$ bash ./dependencies/scripts/docker_build_dependency_image.sh DEVICE='tpu' MODE='nightly'
$ bash ./dependencies/scripts/docker_build_dependency_image.sh DEVICE='tpu' MODE='stable'
$ export MODEL_NAME='llama3_1_70b_8192_synthetic' \
         PROJECT="${GOOGLE_CLOUD_PROJECT?}" \
         ZONE="${GOOGLE_CLOUD_ZONE?}" \
         CLUSTER_NAME="${GOOGLE_CLOUD_CLUSTER_NAME?}" \
         OUTPUT_DIR="${GOOGLE_CLOUD_BUCKET?}" \
         BASE_OUTPUT_DIR="${GOOGLE_CLOUD_BUCKET?}"'/output/' \
         DATASET_PATH="${GOOGLE_CLOUD_BUCKET?}"'/' \
         WORKLOAD='job_name_goes_here'

# Try running every model on the TPU VM.
# Once generated, `successful_local_runs.txt` can be reused instead of
# looping over all models again, until new models are added or
# you move to a more powerful TPU VM.
$ for model_name in 'default' 'llama2-7b' 'llama2-13b' 'llama2-70b' 'llama3-8b' 'llama3-70b' 'llama3.1-8b' \
                    'llama3.1-70b' 'llama3.1-405b' 'llama3.3-70b' 'mistral-7b' 'mixtral-8x7b' \
                    'mixtral-8x22b' 'deepseek2-16b' 'deepseek2-236b' 'deepseek3-671b' \
                    'deepseek3-test' 'deepseek3-tiny' 'kimi-k2-1t' 'gemma-7b' 'gemma-2b' 'gemma2-2b' \
                    'gemma2-9b' 'gemma2-27b' 'gemma3-4b' 'gemma3-12b' 'gemma3-27b' 'qwen3-0.6b' \
                    'qwen3-4b' 'qwen3-4b-thinking-2507' 'qwen3-8b' 'qwen3-14b' 'qwen3-32b' \
                    'qwen3-235b-a22b' 'qwen3-30b-a3b' 'qwen3-480b-a35b' 'qwen3-next-80b-a3b' \
                    'gpt3-175b' 'gpt3-22b' 'gpt3-6b' 'gpt3-52k' 'gpt-oss-20b' 'gpt-oss-120b' \
                    'llama4-17b-16e' 'llama4-17b-128e'; do
  python3 -m MaxText.train MaxText/configs/base.yml \
      run_name="${USER}"'_'"${model_name}"'_002' \
      base_output_directory="${OUTPUT_DIR?}" \
      dataset_type='synthetic' \
      steps='10' \
      model_name="$model_name" && \
  printf '%s\n' "$model_name" >> 'successful_local_runs.txt' || \
  printf '%s\n' "$model_name" >> 'failed_local_runs.txt'
done

$ wc -l 'successful_local_runs.txt'
11 successful_local_runs.txt

$ cat 'successful_local_runs.txt'
default
llama2-7b
mistral-7b
deepseek3-tiny
gemma-2b
gemma2-2b
qwen3-0.6b
qwen3-4b
qwen3-4b-thinking-2507
gpt3-6b
gpt3-52k

$ printf -v command "python3 -m MaxText.train MaxText/configs/base.yml base_output_directory='%s' dataset_path='%s' steps='%d' per_device_batch_size='%d'" \
  "${BASE_OUTPUT_DIR?}" "${DATASET_PATH?}" '100' '1'

$ xpk workload create \
      --base-docker-image 'maxtext_base_image' \
      --zone "${ZONE?}" \
      --cluster "${CLUSTER_NAME?}" \
      --workload "${WORKLOAD?}" \
      --tpu-type='v6e-256' \
      --num-slices='1' \
      --command "${command?}"

# Try running every model on the TPU cluster.
# Once generated, `successful_cluster_runs.txt` can be reused instead of
# looping over all models again, until new models are added or
# you move to different cluster hardware.
$ for model_name in 'default_basic_1' 'default_32' 'default_64' 'default_128' 'default_256' \
                    'default_512' 'gpt_3_175b' 'gpt_3_175b_bf16' 'llama2_7b_4096' \
                    'llama2_70b_4096' 'llama2_70b_4096_synthetic' 'llama2_70b_4096_sc' \
                    'llama2_70b_4096_sc_real_data_tfds' 'llama2_70b_4096_sc_real_data_grain' \
                    'llama2_70b_4096_sc_real_data_grain_checkpoint' 'llama2_70b_4096_rd_lr' \
                    'llama3_8b_8192' 'llama3_70b_8192' 'llama3_1_405b_8192_fsdp_dcn' \
                    'llama3_1_405b_8192_pure_fsdp_ici' 'llama3_1_8b_8192' 'llama3_1_8b_8192_bs5' \
                    'llama3_1_8b_8192_no_collective_matmul' 'llama3_1_70b_8192' 'llama3_1_70b_8192_bs2' \
                    'llama3_1_70b_8192_bs2_bfloat16_no_collective_matmul' 'llama3_1_70b_8192_bs4' \
                    'llama3_1_70b_8192_synthetic' 'llama3_1_70b_8192_rd_grain' \
                    'llama3_1_70b_8192_synthetic_ckpt' 'llama3_1_70b_8192_rd_ckpt_grain' \
                    'llama3_1_70b_8192_pw_lr_rd' 'llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds' \
                    'llama3_1_70b_8192_synth' 'llama3_1_70b_129024' 'mistral_7b' 'mixtral_8x7b_dropless' \
                    'mixtral_8x7b_dropped' 'mixtral_8x7b_dropped_int8' 'mixtral_8x22b_dropped' \
                    'deepseek_v3_ep16' 'gemma2_9b_8192' 'gemma2_27b_8192' \
                    'gemma3_12b_32768_v6e256' 'gemma3_12b_32768_2x_v6e256' \
                    'gemma3_12b_32768_4x_v6e256' 'llama3_1_70b_131072' 'custom_moe_700b' \
                    'llama3_1_405b_8192_v5p_1024' 'deepseek_v3_ep_256_v5p_512' \
                    'llama4_scout_dropless_v5p_256' 'llama4_maverick_dropless_v5p_256' 'llama2_70b_v5p_128' \
                    'llama2_7b_v5p_128' 'gpt_3_175b_v5p_128' 'gpt_3_175b_v5p_128_sc' 'deepseek3_671b_v5p_1024' \
                    'default_16b_v5e_256' 'default_32b_v5e_256' 'default_64b_v5e_256' 'default_128b_v5e_256' \
                    'gpt_3_175b_v5e_256' 'llama2_7b_v5e_256' 'llama2_13b_v5e_256' 'llama2_70b_v5e_256' \
                    'llama3_1_8b_8192_v5e_256' 'deepseek_v3_ep_256_v5p_512_c4mlperf'; do
  python3 -m benchmarks.benchmark_runner xpk \
      --base_docker_image='maxtext_base_image' \
      --project="${PROJECT?}" \
      --zone="${ZONE?}" \
      --cluster_name="${CLUSTER_NAME?}" \
      --device_type='v6e-256' \
      --num_slices='1' \
      --base_output_directory="${OUTPUT_DIR?}" \
      --model_name="$model_name" && \
  printf '%s\n' "$model_name" >> 'successful_cluster_runs.txt' || \
  printf '%s\n' "$model_name" >> 'failed_cluster_runs.txt'
done

$ wc -l 'successful_cluster_runs.txt'
[TBD]

$ cat 'successful_cluster_runs.txt'
[TBD]

TL;DR: these worked:

  • default
  • llama2-7b
  • mistral-7b
  • deepseek3-tiny
  • gemma-2b
  • gemma2-2b
  • qwen3-0.6b
  • qwen3-4b
  • qwen3-4b-thinking-2507
  • gpt3-6b
  • gpt3-52k
  • default_basic_1
  • default_32
  • default_64
  • default_128
  • default_256
  • default_512
  • gpt_3_175b
  • gpt_3_175b_bf16
  • llama2_7b_4096
  • llama2_70b_4096
  • llama2_70b_4096_synthetic
  • llama2_70b_4096_sc
  • llama2_70b_4096_sc_real_data_tfds
  • llama2_70b_4096_sc_real_data_grain
  • llama2_70b_4096_sc_real_data_grain_checkpoint
  • llama2_70b_4096_rd_lr
  • llama3_8b_8192
  • llama3_70b_8192
  • llama3_1_405b_8192_fsdp_dcn
  • llama3_1_405b_8192_pure_fsdp_ici
  • llama3_1_8b_8192
  • llama3_1_8b_8192_bs5
  • llama3_1_8b_8192_no_collective_matmul
  • llama3_1_70b_8192
  • llama3_1_70b_8192_bs2
  • llama3_1_70b_8192_bs2_bfloat16_no_collective_matmul
  • llama3_1_70b_8192_bs4
  • llama3_1_70b_8192_synthetic
  • llama3_1_70b_8192_rd_grain
  • llama3_1_70b_8192_synthetic_ckpt
  • llama3_1_70b_8192_rd_ckpt_grain
  • llama3_1_70b_8192_pw_lr_rd
  • llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds
  • llama3_1_70b_8192_synth
  • llama3_1_70b_129024
  • mistral_7b
  • mixtral_8x7b_dropless
  • mixtral_8x7b_dropped
  • mixtral_8x7b_dropped_int8
  • mixtral_8x22b_dropped
  • deepseek_v3_ep16
  • gemma2_9b_8192
  • gemma2_27b_8192
  • gemma3_12b_32768_v6e256
  • gemma3_12b_32768_2x_v6e256
  • gemma3_12b_32768_4x_v6e256
  • llama3_1_70b_131072
  • custom_moe_700b
  • llama3_1_405b_8192_v5p_1024
  • deepseek_v3_ep_256_v5p_512
  • llama4_scout_dropless_v5p_256
  • llama4_maverick_dropless_v5p_256
  • llama2_70b_v5p_128
  • llama2_7b_v5p_128
  • gpt_3_175b_v5p_128
  • gpt_3_175b_v5p_128_sc
  • deepseek3_671b_v5p_1024
  • default_16b_v5e_256
  • default_32b_v5e_256
  • default_64b_v5e_256
  • default_128b_v5e_256
  • gpt_3_175b_v5e_256
  • llama2_7b_v5e_256
  • llama2_13b_v5e_256
  • llama2_70b_v5e_256
  • llama3_1_8b_8192_v5e_256
  • deepseek_v3_ep_256_v5p_512_c4mlperf

Also ran this manually as an additional test:

python3 -m MaxText.decode MaxText/configs/base.yml \
    model_name=llama2-7b \
    tokenizer_path=assets/tokenizer_llama3.tiktoken \
    tokenizer_type=tiktoken \
    scan_layers=false \
    per_device_batch_size=1 \
    ici_fsdp_parallelism=1 \
    ici_autoregressive_parallelism=-1 \
    max_prefill_predict_length=128 \
    max_target_length=256 \
    prompt="I love to" \
    attention=dot_product

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@richjames0 (Collaborator) commented:

LGTM pending line lengths which make it hard to read

@SamuelMarks requested a review from NuojCheng as a code owner · August 1, 2025 16:23
@bvandermoon (Collaborator) commented:

What do the file name letters stand for? types_j, types_g, etc.

@SamuelMarks (Collaborator, Author) replied:

What do the file name letters stand for? types_j, types_g, etc.

@bvandermoon Oh ignore that. Each is a different attempt (in lexicographical order). All will be removed with a singular types.py to take its place when I rebase this PR to 1 commit.

@SamuelMarks force-pushed the pydantic branch 2 times, most recently from 5683f77 to 2f4bd2b · October 17, 2025 19:32
@richjames0 (Collaborator) commented:

lgtm

@SamuelMarks force-pushed the pydantic branch 4 times, most recently from 11fec39 to 08848be · October 24, 2025 17:47
@bvandermoon (Collaborator) commented:

Generally LGTM, please also add the additional benchmark runner tests you mentioned offline in the PR description

@SamuelMarks force-pushed the pydantic branch 2 times, most recently from e3e08dd to 86d8519 · October 29, 2025 19:20
@bvandermoon (Collaborator) reviewed a commit:

…configuration files in MaxText ;
[src/MaxText/pyconfig.py] New temporary wrapper to not break existing API ;
[src/MaxText/pyconfig_og.py] Move original version here ;
[src/MaxText/configs/__init__.py] Make this a module ;
[tests/pyconfig_test.py] Import from og pyconfig ;
[*requirements*.txt] Add pydantic requirement
@bvandermoon (Collaborator) commented:

Thanks @SamuelMarks for the offline testing. Discussed offline, it looks like the base.yml configs are being picked up, but the model-specific configs are not. Temporarily removing the Pull Ready tag so we can sort this out
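
For context on the layering issue above, a minimal sketch (not MaxText's actual loader) of the merge order a fix needs to preserve: base.yml first, then the model-specific YAML, then explicit overrides, with pydantic validation applied once after the final merge:

from typing import Any, Optional
import yaml  # PyYAML, assumed available
from pydantic import BaseModel

class TrainConfig(BaseModel):
    # Illustrative fields only.
    model_name: str = "default"
    steps: int = 10

def load_config(base_path: str, model_path: Optional[str] = None, **overrides: Any) -> TrainConfig:
    """Merge base.yml, then the model-specific YAML, then overrides;
    validate only the final merged dict so no layer is silently skipped."""
    merged: dict[str, Any] = {}
    with open(base_path) as f:
        merged.update(yaml.safe_load(f) or {})
    if model_path is not None:  # this is the layer that was being missed
        with open(model_path) as f:
            merged.update(yaml.safe_load(f) or {})
    merged.update(overrides)
    known = {k: v for k, v in merged.items() if k in TrainConfig.model_fields}
    return TrainConfig(**known)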

@bvandermoon self-requested a review · November 4, 2025 23:21