     find_labels,
     is_accelerate_available,
     is_apex_available,
+    is_apollo_torch_available,
     is_bitsandbytes_available,
     is_datasets_available,
     is_galore_torch_available,
@@ -1582,6 +1583,119 @@ def optimizer_hook(param):
 
             if args.optim == OptimizerNames.GALORE_ADAFACTOR:
                 optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+        elif args.optim in [
+            OptimizerNames.APOLLO_ADAMW,
+            OptimizerNames.APOLLO_ADAMW_LAYERWISE,
+        ]:
+            if not is_apollo_torch_available():
+                raise ImportError(
+                    "You need to install `apollo_torch` in order to use APOLLO optimizers."
+                    " Install it with `pip install git+https://github.com/zhuhanqing/APOLLO`."
+                )
+            from apollo_torch import APOLLOAdamW
+
+            is_layerwise = args.optim.lower().endswith("layerwise")
+            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
+                raise NotImplementedError("Layer-wise APOLLO does not support DDP at this time")
+
+            optimizer_mapping = {
+                OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
+                OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
+            }
+
+            optimizer_cls = optimizer_mapping[args.optim]
+
+            if args.optim_target_modules is None:
+                raise ValueError(
+                    "You need to define an `optim_target_modules` in order to properly use APOLLO optimizers."
+                )
+
+            if not isinstance(args.optim_target_modules, (list, str)):
+                raise ValueError(
+                    f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, a specific module name, or 'all-linear'; you passed {args.optim_target_modules}."
+                )
+
+            if model is None:
+                raise ValueError("You need to pass a model in order to correctly initialize an APOLLO optimizer.")
+
+            all_linear = (
+                isinstance(args.optim_target_modules, str)
+                and args.optim_target_modules.replace("_", "-") == "all-linear"
+            )
+
+            apollo_params = []
+            apollo_params_names = []
+            for module_name, module in model.named_modules():
+                target_module_exists, is_regex = check_target_module_exists(
+                    args.optim_target_modules, module_name, return_is_regex=True
+                )
+
+                if not isinstance(module, nn.Linear):
+                    # Warn in case we match but it's not a linear layer
+                    if target_module_exists and not is_regex:
+                        logger.warning(
+                            f"{module_name} has been matched but ignored as APOLLO only supports linear layers. Please double check your `optim_target_modules`!"
+                        )
+
+                    continue
+
+                if not target_module_exists and not all_linear:
+                    continue
+
+                apollo_params.append(module.weight)
+                apollo_params_names.append(module_name + ".weight")
+
+            if len(apollo_params) == 0:
+                raise ValueError(
+                    f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `optim_target_modules`."
+                )
+
+            non_apollo_params = [p for n, p in model.named_parameters() if n not in apollo_params_names]
+            # The default args are from the official repository: https://github.com/zhuhanqing/APOLLO
+            apollo_optim_kwargs = {
+                "rank": int(optim_args.pop("rank", 128)),
+                "proj": optim_args.pop("proj", "random"),
+                "scale_type": optim_args.pop("scale_type", "channel"),
+                "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+                "scale": float(optim_args.pop("scale", 1.0)),
+                "proj_type": optim_args.pop("proj_type", "std"),
+            }
+
+            param_groups = [
+                {"params": non_apollo_params},
+                {"params": apollo_params, **apollo_optim_kwargs},
+            ]
+
+            if is_layerwise:
+                # For layer-wise optimizers, the optimization step is done through post accumulation
+                # gradient hooks. The trick is to first attach these hooks to the model parameters, then
+                # create a dummy optimizer that will perform no-ops in the Trainer.
+                # See the original implementation or the nice implementation from @hiyouga
+                # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
+                if args.gradient_accumulation_steps != 1:
+                    raise ValueError("Layer-wise APOLLO optimizers do not support gradient accumulation!")
+
+                optimizer_dict = {}
+                for param in non_apollo_params:
+                    param_groups = [{"params": [param]}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+                for param in apollo_params:
+                    param_groups = [{"params": [param], **apollo_optim_kwargs}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+
+                def optimizer_hook(param):
+                    if param.grad is not None:
+                        optimizer_dict[param].step()
+                        optimizer_dict[param].zero_grad()
+
+                for param in model.parameters():
+                    if param.requires_grad:
+                        param.register_post_accumulate_grad_hook(optimizer_hook)
+
+                optimizer_cls = LayerWiseDummyOptimizer
+                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+            optimizer_kwargs.update({"params": param_groups})
         elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
             if not is_lomo_available():
                 raise ImportError(
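
For context on how this branch is reached, here is a minimal usage sketch. It is illustrative only: the `"apollo_adamw"` / `"apollo_adamw_layerwise"` strings are assumed to be the values of the new `OptimizerNames` members, and the `optim_args` keys are assumed to be forwarded into `apollo_optim_kwargs` the same way the GaLore branch handles them.

```python
# Illustrative sketch only: optimizer name strings and optim_args keys are assumed
# to match the OptimizerNames values and defaults added by this PR.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    optim="apollo_adamw",               # or "apollo_adamw_layerwise"
    optim_target_modules="all-linear",  # required: which linear modules get APOLLO projections
    optim_args="rank=128,proj=random,scale=1.0,update_proj_gap=200",  # parsed into apollo_optim_kwargs
    gradient_accumulation_steps=1,      # must stay at 1 for the layer-wise variant
)
# `args` is then passed to Trainer as usual; Trainer.get_optimizer_cls_and_kwargs routes
# through the branch above and builds the APOLLOAdamW param groups.
```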
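The layer-wise path hinges on `torch.Tensor.register_post_accumulate_grad_hook` (available since PyTorch 2.1). The standalone sketch below reproduces the same trick with `torch.optim.AdamW` standing in for `APOLLOAdamW`, which is why the Trainer can then be handed a no-op `LayerWiseDummyOptimizer`.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# One tiny optimizer per parameter; AdamW stands in for APOLLOAdamW in this sketch.
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param):
    # Runs right after the gradient for `param` has been fully accumulated this step.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

for param in model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(optimizer_hook)

# A single backward pass now also performs the optimizer step, parameter by parameter,
# so the Trainer-side optimizer only needs to be a no-op.
loss = model(torch.randn(2, 8)).sum()
loss.backward()
```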