Add support for Mistral models (#411)
* feat(text-generation): add support for Mistral models

* doc: add mistral to supported architectures

* test: use correct tiny mistral model
dacorvo authored Jan 16, 2024
1 parent 104bd64 commit 43d2f90
Showing 3 changed files with 8 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/package_reference/export.mdx
@@ -66,6 +66,7 @@ Since many architectures share similar properties for their Neuron configuration
| FlauBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| GPT2 | text-generation |
| Llama, Llama 2 | text-generation |
+| Mistral | text-generation |
| MobileBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| MPNet | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| OPT | text-generation |
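With this table entry in place, a Mistral checkpoint can in principle be compiled and run like the other supported decoder architectures. The following is a hedged sketch: the checkpoint name and the compilation arguments (batch_size, sequence_length, num_cores, auto_cast_type) are illustrative assumptions rather than values prescribed by this commit, and running it requires an Inferentia/Trainium instance.

# Sketch: export and run a Mistral model with optimum-neuron.
# Checkpoint and compilation arguments below are illustrative, not from this commit.
from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"

# export=True triggers Neuron compilation on first load; this can take a while.
model = NeuronModelForCausalLM.from_pretrained(
    model_id,
    export=True,
    batch_size=1,
    sequence_length=2048,
    num_cores=2,
    auto_cast_type="bf16",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("What is AWS Neuron?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))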
5 changes: 5 additions & 0 deletions optimum/exporters/neuron/model_configs.py
@@ -471,3 +471,8 @@ def generate_io_aliases(self, model):
            aliases[model.past_key_values_ca[i]] = len(model.past_key_values_sa) + i + num_outputs_from_trace

        return aliases
+
+
+@register_in_tasks_manager("mistral", "text-generation")
+class MistralNeuronConfig(TextNeuronDecoderConfig):
+    NEURONX_CLASS = "mistral.model.MistralForSampling"
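The new config follows the same registration pattern as the other decoders: register the architecture name for the text-generation task and point NEURONX_CLASS at the corresponding transformers-neuronx sampling class. As a hedged illustration of that pattern only, a hypothetical architecture would be wired in like this (all names below are placeholders, not part of this commit or of transformers-neuronx):

# Hypothetical example of the registration pattern; "myarch" and
# "myarch.model.MyArchForSampling" are placeholders, not real classes.
@register_in_tasks_manager("myarch", "text-generation")
class MyArchNeuronConfig(TextNeuronDecoderConfig):
    NEURONX_CLASS = "myarch.model.MyArchForSampling"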
3 changes: 2 additions & 1 deletion tests/generation/conftest.py
@@ -22,11 +22,12 @@
from optimum.utils.testing_utils import USER


-DECODER_MODEL_ARCHITECTURES = ["bloom", "gpt2", "llama", "opt"]
+DECODER_MODEL_ARCHITECTURES = ["bloom", "gpt2", "llama", "mistral", "opt"]
DECODER_MODEL_NAMES = {
    "bloom": "hf-internal-testing/tiny-random-BloomForCausalLM",
    "gpt2": "hf-internal-testing/tiny-random-gpt2",
    "llama": "dacorvo/tiny-random-llama",
+    "mistral": "dacorvo/tiny-random-MistralForCausalLM",
    "opt": "hf-internal-testing/tiny-random-OPTForCausalLM",
}
SEQ2SEQ_MODEL_NAMES = {
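The generation tests presumably consume these mappings through a parametrized fixture, so that each tiny checkpoint (now including the tiny Mistral model) is exercised in turn. A minimal sketch of such a fixture, assuming pytest and the two dictionaries above; the actual fixture in conftest.py may differ in name and scope:

import pytest

# Sketch only: parametrize over the decoder architectures and resolve each
# one to its tiny test checkpoint.
@pytest.fixture(scope="module", params=DECODER_MODEL_ARCHITECTURES)
def decoder_model_id(request):
    return DECODER_MODEL_NAMES[request.param]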
