diff --git a/README.md b/README.md index 830fe735..348edb6f 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,9 @@ We currently support a few LLM models targeting text generation scenarios: ## Installation +For installation on a TPU v4, use the `install-on-TPU-v4.sh` script. Make sure that you DO NOT install pallas or Jetstream as both are targeting TPU v5e! + +Via package: `optimum-tpu` comes with an handy PyPi released package compatible with your classical python dependency management tool. `pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html` diff --git a/install-on-TPU-v4.sh b/install-on-TPU-v4.sh new file mode 100644 index 00000000..de7dec81 --- /dev/null +++ b/install-on-TPU-v4.sh @@ -0,0 +1,24 @@ +sudo apt remove unattended-upgrades +sudo apt update +export PJRT_DEVICE=TPU +export PATH="$HOME/.local/bin:$PATH" +pip install build +pip install --upgrade setuptools +sudo apt install python3.10-venv + +git clone https://github.com/huggingface/optimum-tpu.git + +cd optimum-tpu +make +make build_dist_install_tools +make build_dist + +python -m venv optimum_tpu_env +source optimum_tpu_env/bin/activate + +pip install torch==2.4.0 torch_xla[tpu]==2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html +pip uninstall torchvision # it might insist von 2.4.1 +pip install -e . + +huggingface-cli login +gsutil cp -r gs://entropix/huggingface_hub ~/.cache/huggingface/hub diff --git a/pyproject.toml b/pyproject.toml index b9d4c9d4..0078a7dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,10 +61,11 @@ tests = ["pytest", "safetensors"] quality = ["black", "ruff", "isort"] # Jetstream/Pytorch support is experimental for now, it needs to be installed manually. # Pallas is pulled because it will install a compatible version of jax[tpu]. -jetstream-pt = [ - "jetstream-pt", - "torch-xla[pallas] == 2.4.0" -] +# pallas and jetstream are not supported before v5e. Therefore, comment out on v4 and earlier +#jetstream-pt = [ +# "jetstream-pt", +# "torch-xla[pallas] == 2.4.0" +#] [project.urls] Homepage = "https://hf.co/hardware"