Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions docs/source/en/trainer_recipes.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,57 @@ trainer.train(resume_from_checkpoint="out/checkpoint-1000")
When resuming, [`Trainer`] restores the optimizer state, scheduler state, and RNG state.

Checkpoint resuming requires optimizer and scheduler state files in the checkpoint directory. If those files are missing (for example, when `save_only_model=True`), the optimizer restarts from scratch.

### JIT checkpointing

With periodic checkpointing (save_strategy="steps" or "epoch"), you lose any training progress between the last saved checkpoint and an interruption. On shared clusters with preemptible workloads such as [Kueue](https://kueue.sigs.k8s.io/), jobs can be terminated at any time, so that gap can mean hours of wasted compute.

JIT (Just-In-Time) checkpointing closes this gap. When the trainer receives a SIGTERM signal, it saves a checkpoint at the exact point training was interrupted, so you resume with minimal loss of progress. It works alongside periodic checkpointing. Periodic saves guard against crashes and hardware failures, while JIT saves guard against preemption and graceful shutdowns.

Enable it by setting `enable_jit_checkpoint=True` in [`TrainingArguments`].

```py
from transformers import TrainingArguments

training_args = TrainingArguments(
output_dir="your-model",
enable_jit_checkpoint=True,
)
```

When SIGTERM is received, [`Trainer`] waits for the current training step to finish, saves a checkpoint, and stops training gracefully. A sentinel file (`checkpoint-is-incomplete.txt`) is written when the save begins and removed once the checkpoint is fully written. If a checkpoint directory still contains this file, the save was interrupted before completing. [`Trainer`] doesn't check for it automatically, so inspect for it yourself before resuming.

Resume from the JIT checkpoint the same way as any other checkpoint.

```py
trainer.train(resume_from_checkpoint=True)
```

> [!WARNING]
> You must configure your orchestrator to allow enough time for the checkpoint to complete. The default Kubernetes graceful shutdown period is only 30 seconds, which is typically not enough for larger models.

<hfoptions id="orchestrator-grace-period">
<hfoption id="Kubernetes">

Set `terminationGracePeriodSeconds` in your Pod or Job spec. The exact field location varies by trainer (Kubeflow Training Operator, Ray, etc.).

```yaml
spec:
template:
spec:
terminationGracePeriodSeconds: 300
```

</hfoption>
<hfoption id="Slurm">

Use `--signal=TERM@<seconds>` in your sbatch script to send SIGTERM before the job time limit expires.

```bash
#SBATCH --signal=TERM@300
```

</hfoption>
</hfoptions>

Calculate the required grace period as the longest possible training step time plus the checkpoint saving time, plus the 3 second `kill_wait` delay before the checkpoint begins. For example, if a training step takes up to 2 minutes and saving a checkpoint takes 2 minutes, set at least 243 seconds of grace time.
2 changes: 1 addition & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ class TrainingArguments:
Enable Just-In-Time checkpointing on SIGTERM signal for graceful termination on
preemptible workloads. **Important**: Configure your orchestrator's graceful shutdown
period to allow sufficient time. For Kubernetes, set `terminationGracePeriodSeconds`
(default 30s is usually insufficient). For Slurm, use `--signal=USR1@<seconds>`.
(default 30s is usually insufficient). For Slurm, use `--signal=TERM@<seconds>`.
Required grace period ≥ longest iteration time + checkpoint save time.
> Hugging Face Hub Integration
Expand Down
Loading