
Commit eac6a62

Added APOLLO optimizer integration
1 parent 0de15c9 commit eac6a62

File tree

7 files changed: +428, -1 lines changed


docs/source/en/trainer.md

+133
@@ -445,6 +445,139 @@ trainer.train()
Note that layer-wise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), so you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.

### APOLLO

Approximated Gradient Scaling for Memory Efficient LLM Optimization (APOLLO) is a memory-efficient, low-rank training strategy that allows full-parameter learning for both pre-training and fine-tuning, while maintaining AdamW-level performance with SGD-like memory efficiency.

* **Ultra-low rank efficiency** → Requires a much lower rank than GaLore; even rank 1 (APOLLO-Mini) suffices.
* **No expensive SVD computations** → Unlike GaLore, APOLLO leverages random projection, avoiding training stalls (see the sketch after this list).
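
As a rough intuition for the second point, here is a minimal, hypothetical sketch (not APOLLO's actual implementation; `random_project` is an illustrative helper) of compressing a gradient with a random projection instead of an SVD-derived subspace:

```python
import torch

def random_project(grad: torch.Tensor, rank: int, seed: int = 0) -> torch.Tensor:
    """Compress a 2D gradient into a rank-`rank` subspace with a random matrix.

    Sampling the projection is cheap (no SVD of the gradient is needed),
    which is the property that avoids the periodic stalls of SVD updates.
    """
    n = grad.shape[1]
    generator = torch.Generator(device=grad.device).manual_seed(seed)
    # Scale by 1/sqrt(rank) to roughly preserve the gradient's variance
    projection = torch.randn(n, rank, generator=generator, device=grad.device) / rank**0.5
    return grad @ projection  # optimizer states live in this low-rank space

grad = torch.randn(4096, 4096)
print(random_project(grad, rank=1).shape)  # torch.Size([4096, 1]); APOLLO-Mini uses rank 1
```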

First make sure to install APOLLO from its official repository:

```bash
pip install apollo-torch
```

Then add `apollo_adamw` as the `optim` value, together with `optim_target_modules`, which can be a list of strings, regexes, or full paths corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```

To pass extra arguments supported by APOLLO, pass them via `optim_args`, for example:

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
    optim_args="proj=random,scale_type=tensor,rank=128,update_proj_gap=100,scale=1.0",
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```

Currently only `nn.Linear` layers are trained with the APOLLO optimizer, while the remaining modules are still optimized with AdamW.
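
If you are unsure which layers your patterns will match, a quick approximate way to preview them is to filter `model.named_modules()` yourself. The snippet below is an illustrative sketch, not the Trainer's exact matching logic:

```python
import re

import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_config(config)

patterns = [r".*.attn.*", r".*.mlp.*"]
# Only nn.Linear weights are handed to APOLLO; everything else keeps AdamW
matched = [
    name
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear) and any(re.fullmatch(p, name) for p in patterns)
]
print(f"{len(matched)} linear layers would be handed to APOLLO")
```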
You can read more about the method in the [original repository](https://github.com/zhuhanqing/APOLLO) or the [paper](https://arxiv.org/abs/2412.05270).

You can also perform layer-wise APOLLO by simply appending `layerwise` to the optimizer name, as below:
```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw_layerwise",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```

As with layer-wise GaLore, layer-wise APOLLO does not support DDP or gradient accumulation, so you can only run it on a single GPU.

### LOMO optimizer

The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).

src/transformers/testing_utils.py

+9
@@ -62,6 +62,7 @@
     GGUF_MIN_VERSION,
     is_accelerate_available,
     is_apex_available,
+    is_apollo_torch_available,
     is_aqlm_available,
     is_auto_awq_available,
     is_auto_gptq_available,

@@ -403,6 +404,14 @@ def require_galore_torch(test_case):
     return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
 
 
+def require_apollo_torch(test_case):
+    """
+    Decorator marking a test that requires APOLLO. These tests are skipped when apollo-torch isn't installed.
+    https://github.com/zhuhanqing/APOLLO
+    """
+    return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)
+
+
 def require_lomo(test_case):
     """
     Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
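
For context, the new decorator follows the same pattern as `require_galore_torch`; a minimal sketch of how it would be used (the test class and body here are hypothetical):

```python
import unittest

from transformers.testing_utils import require_apollo_torch


class ApolloAvailabilityTest(unittest.TestCase):
    @require_apollo_torch
    def test_apollo_optimizer_importable(self):
        # Only runs when apollo-torch is installed; skipped otherwise.
        from apollo_torch import APOLLOAdamW  # noqa: F401
```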

src/transformers/trainer.py

+114
@@ -151,6 +151,7 @@
     find_labels,
     is_accelerate_available,
     is_apex_available,
+    is_apollo_torch_available,
     is_bitsandbytes_available,
     is_datasets_available,
     is_galore_torch_available,

@@ -1582,6 +1583,119 @@ def optimizer_hook(param):
 
         if args.optim == OptimizerNames.GALORE_ADAFACTOR:
             optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+        elif args.optim in [
+            OptimizerNames.APOLLO_ADAMW,
+            OptimizerNames.APOLLO_ADAMW_LAYERWISE,
+        ]:
+            if not is_apollo_torch_available():
+                raise ImportError(
+                    "You need to install `apollo_torch` in order to use APOLLO optimizers,"
+                    " install it with `pip install git+https://github.com/zhuhanqing/APOLLO`"
+                )
+            from apollo_torch import APOLLOAdamW
+
+            is_layerwise = args.optim.lower().endswith("layerwise")
+            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
+                raise NotImplementedError("Layer-wise APOLLO does not support DDP at this time")
+
+            optimizer_mapping = {
+                OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
+                OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
+            }
+
+            optimizer_cls = optimizer_mapping[args.optim]
+
+            if args.optim_target_modules is None:
+                raise ValueError(
+                    "You need to define `optim_target_modules` in order to properly use APOLLO optimizers"
+                )
+
+            if not isinstance(args.optim_target_modules, (list, str)):
+                raise ValueError(
+                    f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, a specific module name, or 'all-linear', you passed {args.optim_target_modules}"
+                )
+
+            if model is None:
+                raise ValueError("You need to pass a model in order to correctly initialize an APOLLO optimizer.")
+
+            all_linear = (
+                isinstance(args.optim_target_modules, str)
+                and args.optim_target_modules.replace("_", "-") == "all-linear"
+            )
+
+            apollo_params = []
+            apollo_params_names = []
+            for module_name, module in model.named_modules():
+                target_module_exists, is_regex = check_target_module_exists(
+                    args.optim_target_modules, module_name, return_is_regex=True
+                )
+
+                if not isinstance(module, nn.Linear):
+                    # Warn in case we match but it's not a linear layer
+                    if target_module_exists and not is_regex:
+                        logger.warning(
+                            f"{module_name} has been matched but ignored as APOLLO only supports linear layers. Please double check your `optim_target_modules`!"
+                        )
+
+                    continue
+
+                if not target_module_exists and not all_linear:
+                    continue
+
+                apollo_params.append(module.weight)
+                apollo_params_names.append(module_name + ".weight")
+
+            if len(apollo_params) == 0:
+                raise ValueError(
+                    f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `optim_target_modules`."
+                )
+
+            non_apollo_params = [p for n, p in model.named_parameters() if n not in apollo_params_names]
+            # The default args are from the official repository: https://github.com/zhuhanqing/APOLLO
+            apollo_optim_kwargs = {
+                "rank": int(optim_args.pop("rank", 128)),
+                "proj": optim_args.pop("proj", "random"),
+                "scale_type": optim_args.pop("scale_type", "channel"),
+                "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+                "scale": float(optim_args.pop("scale", 1.0)),
+                "proj_type": optim_args.pop("proj_type", "std"),
+            }
+
+            param_groups = [
+                {"params": non_apollo_params},
+                {"params": apollo_params, **apollo_optim_kwargs},
+            ]
+
+            if is_layerwise:
+                # For layer-wise optimizers, the optimization step is done through post accumulation
+                # gradient hooks. The trick is to first attach these hooks to the model parameters then
+                # create a dummy optimizer that will perform no-ops in the Trainer.
+                # See the original implementation or the nice implementation from @hiyouga
+                # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
+                if args.gradient_accumulation_steps != 1:
+                    raise ValueError("The layer-wise APOLLO optimizer does not support gradient accumulation!")
+
+                optimizer_dict = {}
+                for param in non_apollo_params:
+                    param_groups = [{"params": [param]}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+                for param in apollo_params:
+                    param_groups = [{"params": [param], **apollo_optim_kwargs}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+
+                def optimizer_hook(param):
+                    if param.grad is not None:
+                        optimizer_dict[param].step()
+                        optimizer_dict[param].zero_grad()
+
+                for param in model.parameters():
+                    if param.requires_grad:
+                        param.register_post_accumulate_grad_hook(optimizer_hook)
+
+                optimizer_cls = LayerWiseDummyOptimizer
+                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+            optimizer_kwargs.update({"params": param_groups})
         elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
             if not is_lomo_available():
                 raise ImportError(
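
For reference, here is a minimal sketch (assumed string handling, mirroring the keys popped above) of how a comma-separated `optim_args` value turns into the APOLLO keyword arguments:

```python
# Parse the `optim_args` string into a dict, then pop the APOLLO keys
# exactly as the Trainer code above does, falling back to the defaults.
optim_args_str = "proj=random,scale_type=tensor,rank=128,update_proj_gap=100,scale=1.0"
optim_args = dict(kv.split("=") for kv in optim_args_str.replace(" ", "").split(","))

apollo_optim_kwargs = {
    "rank": int(optim_args.pop("rank", 128)),
    "proj": optim_args.pop("proj", "random"),
    "scale_type": optim_args.pop("scale_type", "channel"),
    "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
    "scale": float(optim_args.pop("scale", 1.0)),
    "proj_type": optim_args.pop("proj_type", "std"),
}
print(apollo_optim_kwargs)
# {'rank': 128, 'proj': 'random', 'scale_type': 'tensor', 'update_proj_gap': 100, 'scale': 1.0, 'proj_type': 'std'}
```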

src/transformers/training_args.py

+2
@@ -184,6 +184,8 @@ class OptimizerNames(ExplicitEnum):
     GROKADAMW = "grokadamw"
     SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
     SCHEDULE_FREE_SGD = "schedule_free_sgd"
+    APOLLO_ADAMW = "apollo_adamw"
+    APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
 
 
 # Sometimes users will pass in a `str` repr of a dict in the CLI
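
With these enum values registered, the new optimizers can be selected by name from `TrainingArguments`; a minimal sketch, reusing the target-module patterns from the docs above:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    optim="apollo_adamw",  # or "apollo_adamw_layerwise"
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
)
```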

src/transformers/utils/__init__.py

+1
@@ -118,6 +118,7 @@
     get_torch_version,
     is_accelerate_available,
     is_apex_available,
+    is_apollo_torch_available,
     is_aqlm_available,
     is_auto_awq_available,
     is_auto_gptq_available,

src/transformers/utils/import_utils.py

+5
@@ -98,6 +98,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 
 _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
 _apex_available = _is_package_available("apex")
+_apollo_torch_available = _is_package_available("apollo_torch")
 _aqlm_available = _is_package_available("aqlm")
 _vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
 _av_available = importlib.util.find_spec("av") is not None

@@ -402,6 +403,10 @@ def is_galore_torch_available():
     return _galore_torch_available
 
 
+def is_apollo_torch_available():
+    return _apollo_torch_available
+
+
 def is_lomo_available():
     return _lomo_available
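
Following the pattern of the surrounding availability checks, downstream code can guard APOLLO imports with the new helper; an illustrative sketch:

```python
from transformers.utils import is_apollo_torch_available

if is_apollo_torch_available():
    from apollo_torch import APOLLOAdamW
else:
    APOLLOAdamW = None  # fall back, or raise a helpful error at use time
```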
