Skip to content

Commit

Permalink
Support phi model on feature-extraction, text-classification, token-classification tasks (#509)

Browse files Browse the repository at this point in the history

* support phi

* update doc

* fix style
  • Loading branch information
JingyaHuang authored Mar 11, 2024
1 parent 8f84127 commit b140cd5
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 6 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ transformers_examples:
# Run code quality checks
style_check:
black --check .
ruff .
ruff check .

style:
black .
ruff . --fix
ruff check . --fix

# Utilities to release to PyPi
build_dist_install_tools:
Expand Down
1 change: 1 addition & 0 deletions docs/source/package_reference/supported_models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
| MobileBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| MPNet | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| OPT | text-generation |
| Phi | feature-extraction, text-classification, token-classification |
| RoBERTa | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| RoFormer | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| T5 | text2text-generation |
Expand Down
13 changes: 13 additions & 0 deletions optimum/exporters/neuron/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
VisionNeuronConfig,
)
from .model_wrappers import (
NoCacheModelWrapper,
SentenceTransformersCLIPNeuronWrapper,
SentenceTransformersTransformerNeuronWrapper,
T5DecoderWrapper,
Expand Down Expand Up @@ -122,6 +123,18 @@ class MobileBertNeuronConfig(BertNeuronConfig):
pass


@register_in_tasks_manager("phi", "feature-extraction", "text-classification", "token-classification")
class PhiNeuronConfig(ElectraNeuronConfig):
    """Neuron export configuration for Phi checkpoints.

    Phi is traced without a KV cache, so the model is wrapped with
    ``NoCacheModelWrapper`` before export.
    """

    CUSTOM_MODEL_WRAPPER = NoCacheModelWrapper

    @property
    def inputs(self) -> List[str]:
        # Only the token ids and the attention mask are fed to the traced graph.
        return ["input_ids", "attention_mask"]

    def patch_model_for_export(self, model, dummy_inputs):
        # Hand the wrapper the dummy-input names in order, so positional
        # tensors can be mapped back to keyword arguments at trace time.
        input_names = list(dummy_inputs.keys())
        return self.CUSTOM_MODEL_WRAPPER(model, input_names)


@register_in_tasks_manager("roformer", *COMMON_TEXT_TASKS)
class RoFormerNeuronConfig(ElectraNeuronConfig):
pass
Expand Down
13 changes: 13 additions & 0 deletions optimum/exporters/neuron/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,16 @@ def forward(self, input_ids, pixel_values, attention_mask):
text_embeds = self.model[1:](text_embeds)

return (text_embeds, image_embeds)


class NoCacheModelWrapper(torch.nn.Module):
    """Wrap a model so its forward pass always runs with ``use_cache=False``.

    Useful when tracing/exporting encoder-style tasks on decoder checkpoints:
    the traced graph must not produce KV-cache outputs.

    Args:
        model: The wrapped model (annotated as ``PreTrainedModel`` upstream;
            any module whose forward accepts ``use_cache`` works).
        input_names: Names matched positionally, in order, to the tensors
            passed to :meth:`forward`.
    """

    def __init__(self, model: "PreTrainedModel", input_names: List[str]):
        super().__init__()
        self.model = model
        self.input_names = input_names

    def forward(self, *inputs):
        # `inputs` (not `input`) so the builtin is not shadowed.
        # zip silently drops extras if counts mismatch — callers must pass
        # exactly one tensor per name.
        ordered_inputs = dict(zip(self.input_names, inputs))
        # Force the cache off for every call so the export sees a cache-free graph.
        outputs = self.model(use_cache=False, **ordered_inputs)

        return outputs
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@ target-version = ['py38']
extend-exclude = '.ipynb'

[tool.ruff]
line-length = 119

[tool.ruff.lint]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
line-length = 119

# Ignore import violations in all `__init__.py` files.
[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]

[tool.ruff.isort]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["optimum.neuron"]

Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"flaubert": "flaubert/flaubert_small_cased",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"phi": "hf-internal-testing/tiny-random-PhiModel",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
"xlm": "hf-internal-testing/tiny-random-XLMModel",
Expand Down
1 change: 1 addition & 0 deletions tests/inference/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"latent-consistency": "echarlaix/tiny-random-latent-consistency",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"phi": "hf-internal-testing/tiny-random-PhiModel",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def fetch_model(
config = GenerationConfig.from_pretrained(model_id, revision=revision)
config.save_pretrained(export_path)
logger.info(f"Saved model default generation config under {export_path}.")
except:
except Exception:
logger.warning(f"No default generation config found for {model_id}.")
logger.info(f"Model successfully exported in {end - start:.2f} s under {export_path}.")
return export_path

0 comments on commit b140cd5

Please sign in to comment.