Skip to content

Commit

Permalink
Add support for Mixtral (#569)
Browse files Browse the repository at this point in the history
feat(decoder): add support for Mixtral
  • Loading branch information
dacorvo authored Apr 15, 2024
1 parent 4429bb6 commit c3daf50
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/source/package_reference/supported_models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ limitations under the License.
| GPT2 | text-generation |
| Llama, Llama 2 | text-generation |
| Mistral | text-generation |
| Mixtral | text-generation |
| MobileBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| MPNet | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| OPT | text-generation |
Expand Down Expand Up @@ -69,4 +70,4 @@ More details for checking supported tasks [here](https://huggingface.co/docs/opt

</Tip>

More architectures coming soon, stay tuned! 🚀
More architectures coming soon, stay tuned! 🚀
6 changes: 6 additions & 0 deletions optimum/exporters/neuron/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,9 @@ def generate_io_aliases(self, model):
class MistralNeuronConfig(TextNeuronDecoderConfig):
    """Neuron decoder export configuration for Mistral text-generation models."""

    # Dotted path of the sampling model class backing this architecture
    # (presumably resolved inside the transformers-neuronx package — confirm
    # against TextNeuronDecoderConfig, which consumes this attribute).
    NEURONX_CLASS = "mistral.model.MistralForSampling"
    # Marks the architecture as supporting continuous batching at inference.
    CONTINUOUS_BATCHING = True


@register_in_tasks_manager("mixtral", "text-generation")
class MixtralNeuronConfig(TextNeuronDecoderConfig):
    """Neuron decoder export configuration for Mixtral text-generation models.

    Mirrors MistralNeuronConfig: registered for the "text-generation" task and
    backed by a neuronx sampling model class.
    """

    # Dotted path of the sampling model class backing this architecture
    # (presumably resolved inside the transformers-neuronx package — confirm
    # against TextNeuronDecoderConfig, which consumes this attribute).
    NEURONX_CLASS = "mixtral.model.MixtralForSampling"
    # Marks the architecture as supporting continuous batching at inference.
    CONTINUOUS_BATCHING = True
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"xlm",
"roberta",
]
DECODER_ARCHITECTURES = ["gpt2", "llama"]
DECODER_ARCHITECTURES = ["gpt2", "llama", "mixtral"]
DIFFUSER_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl"]

INFERENTIA_MODEL_NAMES = {
Expand All @@ -46,6 +46,7 @@
"flaubert": "flaubert/flaubert_small_cased",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"llama": "dacorvo/tiny-random-llama",
"mixtral": "dacorvo/Mixtral-tiny",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
Expand Down
3 changes: 2 additions & 1 deletion tests/generation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
from optimum.utils.testing_utils import USER


DECODER_MODEL_ARCHITECTURES = ["bloom", "gpt2", "llama", "mistral", "opt"]
DECODER_MODEL_ARCHITECTURES = ["bloom", "gpt2", "llama", "mistral", "mixtral", "opt"]
DECODER_MODEL_NAMES = {
"bloom": "hf-internal-testing/tiny-random-BloomForCausalLM",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"llama": "dacorvo/tiny-random-llama",
"mistral": "dacorvo/tiny-random-MistralForCausalLM",
"mixtral": "dacorvo/Mixtral-tiny",
"opt": "hf-internal-testing/tiny-random-OPTForCausalLM",
}
TRN_DECODER_MODEL_ARCHITECTURES = ["bloom", "llama", "opt"]
Expand Down

0 comments on commit c3daf50

Please sign in to comment.