Skip to content

Commit

Permalink
Merge branch 'main' into add-pixart-support
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Nov 4, 2024
2 parents 687de69 + 02c331d commit a5f76a3
Show file tree
Hide file tree
Showing 8 changed files with 390 additions and 18 deletions.
8 changes: 4 additions & 4 deletions optimum/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

_import_structure = {
"hf_argparser": ["NeuronHfArgumentParser"],
"trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer", "NeuronSFTTrainer"],
"trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer", "NeuronSFTTrainer", "NeuronORPOTrainer"],
"training_args": ["NeuronTrainingArguments", "Seq2SeqNeuronTrainingArguments"],
"modeling_traced": ["NeuronTracedModel"],
"modeling": [
Expand Down Expand Up @@ -69,7 +69,7 @@
"ModelParallelismPlugin",
],
"pipelines": ["pipeline"],
"utils": ["NeuronSFTConfig", "get_peft_model"],
"utils": ["NeuronSFTConfig", "NeuronORPOConfig", "get_peft_model"],
}

if TYPE_CHECKING:
Expand Down Expand Up @@ -109,9 +109,9 @@
from .modeling_seq2seq import NeuronModelForSeq2SeqLM
from .modeling_traced import NeuronTracedModel
from .pipelines import pipeline
from .trainers import NeuronSFTTrainer, NeuronTrainer, Seq2SeqNeuronTrainer
from .trainers import NeuronORPOTrainer, NeuronSFTTrainer, NeuronTrainer, Seq2SeqNeuronTrainer
from .training_args import NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments
from .utils import NeuronSFTConfig, get_peft_model
from .utils import NeuronORPOConfig, NeuronSFTConfig, get_peft_model

else:
import sys
Expand Down
355 changes: 350 additions & 5 deletions optimum/neuron/trainers.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions optimum/neuron/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ class NeuronTrainingArgumentsMixin:
)

def __post_init__(self):
if self.neuron_cc_flags_model_type is not None:
os.environ["OPTIMUM_NEURON_COMMON_FLAGS_MODEL_TYPE"] = self.neuron_cc_flags_model_type

# Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
patch_accelerate_is_torch_xla_available()

Expand Down Expand Up @@ -221,6 +224,11 @@ def __post_init__(self):
def _setup_devices(self) -> "torch.device":
return super()._setup_devices

@property
def neuron_cc_flags_model_type(self) -> Optional[str]:
"""Controls the value to provide to the Neuron Compiler for the model-type flag."""
return "transformer"

@property
def place_model_on_device(self):
return not self.mp_plugin.should_parallelize and super().place_model_on_device
Expand Down
4 changes: 2 additions & 2 deletions optimum/neuron/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
"is_model_officially_supported",
"patch_transformers_for_neuron_sdk",
],
"trl_utils": ["NeuronSFTConfig"],
"trl_utils": ["NeuronSFTConfig", "NeuronORPOConfig"],
}

if TYPE_CHECKING:
Expand Down Expand Up @@ -137,7 +137,7 @@
is_model_officially_supported,
patch_transformers_for_neuron_sdk,
)
from .trl_utils import NeuronSFTConfig
from .trl_utils import NeuronORPOConfig, NeuronSFTConfig
else:
import sys

Expand Down
10 changes: 7 additions & 3 deletions optimum/neuron/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,16 @@ def is_torch_neuronx_available() -> bool:
return importlib.util.find_spec("torch_neuronx") is not None


def is_trl_available() -> bool:
def is_trl_available(required_version: Optional[str] = None) -> bool:
trl_available = importlib.util.find_spec("trl") is not None
if trl_available:
import trl

if version.parse(trl.__version__) >= version.parse("0.10.0"):
if required_version is None:
required_version = trl.__version__

if version.parse(trl.__version__) == version.parse(required_version):
return True
raise RuntimeError("Only `trl` 0.10.0 and more recent is supported.")

raise RuntimeError(f"Only `trl=={required_version}` is supported, but {trl.__version__} is installed.")
return False
5 changes: 3 additions & 2 deletions optimum/neuron/utils/torch_xla_and_neuronx_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def set_common_flags():
"""
Sets environment variables for transformer-based models training with AWS Neuron.
"""
# Set compiler flag to compile for transformer model type
os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --model-type=transformer"
model_type = os.environ.get("OPTIMUM_NEURON_COMMON_FLAGS_MODEL_TYPE", "")
if model_type != "":
os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + f" --model-type={model_type}"
# Setting MALLOC_ARENA_MAX is needed because of a memory issue in XLA/glic, otherwise OOM can happen during
# checkpointing. More information here:
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/torch/torch-neuronx/index.html#memory-leaking-in-glibc
Expand Down
16 changes: 15 additions & 1 deletion optimum/neuron/utils/trl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,35 @@
"""Utilities related to the TRL library and support."""

from dataclasses import dataclass
from typing import Optional

from ..training_args import NeuronTrainingArguments
from .import_utils import is_trl_available


if is_trl_available():
from trl import SFTConfig
from trl import ORPOConfig, SFTConfig
else:

@dataclass
class SFTConfig:
def __init__(self, *args, **kwargs):
raise RuntimeError("You need to install the `trl` library to use the `NeuronSFTConfig`.")

@dataclass
class ORPOConfig:
def __init__(self, *args, **kwargs):
raise RuntimeError("You need to install the `trl` library to use the `NeuronSFTConfig`.")


@dataclass
class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig):
pass


@dataclass
class NeuronORPOConfig(NeuronTrainingArguments, ORPOConfig):

@property
def neuron_cc_flags_model_type(self) -> Optional[str]:
return None
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"safetensors",
"sentence-transformers >= 2.2.0",
"peft",
"trl",
"trl==0.11.4",
"compel",
"rjieba",
"soundfile",
Expand Down

0 comments on commit a5f76a3

Please sign in to comment.