
Commit 18aaba1

lapp0 authored and rlouf committed
add outlines.models.mlxlm
1 parent 742eb89 commit 18aaba1

File tree

13 files changed: +645 -4 lines changed


docs/reference/models/mlxlm.md

Lines changed: 32 additions & 0 deletions
# mlx-lm

Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx-examples/tree/main/llms), allowing models to be run quickly on Apple Silicon via the [mlx](https://ml-explore.github.io/mlx/build/html/index.html) library.

## Installation

In addition to `outlines`, you must install the `mlx-lm` and `mlx` libraries (both are available on PyPI, e.g. `pip install mlx mlx-lm`). You must use a device which [supports Metal](https://support.apple.com/en-us/102894).
## Using `models.mlxlm`

```python
from outlines import models

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
```

With the loaded model, you can generate text or perform structured generation, for example:

```python
from outlines import models, generate

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

phone_number_pattern = "\\+?[1-9][0-9]{7,14}"
generator = generate.regex(model, phone_number_pattern)

model_output = generator("What's Jenny's number?\n")
print(model_output)
# '8675309'
```

For more examples, see the [cookbook](cookbook/index.md).
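
Token-by-token streaming should also be possible: the integration generates one token at a time under the hood (see `MLXLM.stream` in `outlines/models/mlxlm.py` below). A minimal sketch, assuming `SequenceGeneratorAdapter` exposes the same `stream()` interface as the other Outlines backends:

```python
from outlines import models, generate

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
generator = generate.text(model)

# Print tokens as they arrive instead of waiting for the full completion.
for token in generator.stream("Describe Apple Silicon in one sentence."):
    print(token, end="")
```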

mkdocs.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -125,6 +125,7 @@ nav:
       - vLLM: reference/models/vllm.md
       - Llama.cpp: reference/models/llamacpp.md
       - Transformers: reference/models/transformers.md
+      - MLX: reference/models/mlxlm.md
       - ExllamaV2: reference/models/exllamav2.md
       - Mamba: reference/models/mamba.md
       - OpenAI: reference/models/openai.md
```

outlines/generate/cfg.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -4,6 +4,7 @@
 from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.models.llamacpp import LlamaCpp
+from outlines.models.mlxlm import MLXLM
 from outlines.models.vllm import VLLM
 from outlines.samplers import Sampler, multinomial

@@ -33,14 +34,15 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera
     return generator


+@cfg.register(MLXLM)
 @cfg.register(VLLM)
-def cfg_vllm(
-    model: VLLM,
+def cfg_unimplemented(
+    model,
     cfg_str: str,
     sampler: Sampler = multinomial(),
 ):
     raise NotImplementedError(
-        "The CFG Logits processor is not available for the vLLM integration."
+        f"The CFG Logits processor is not available for {type(model)}."
     )
```
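
Both `MLXLM` and `VLLM` now share a single stub, so CFG-guided generation fails loudly with a model-specific message. A sketch of the resulting behaviour (the grammar string is an arbitrary placeholder, since the stub raises before parsing it):

```python
from outlines import models, generate

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

try:
    generate.cfg(model, "start: /[0-9]+/")  # placeholder grammar
except NotImplementedError as e:
    print(e)  # The CFG Logits processor is not available for <class '...MLXLM'>
```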

outlines/generate/regex.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.models.llamacpp import LlamaCpp
+from outlines.models.mlxlm import MLXLM
 from outlines.models.vllm import VLLM
 from outlines.samplers import Sampler, multinomial

@@ -37,6 +38,18 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
     return generator


+@regex.register(MLXLM)
+def regex_mlxlm(
+    model: MLXLM,
+    regex_str: str,
+    sampler: Sampler = multinomial(),
+):
+    from outlines.processors import RegexLogitsProcessor
+
+    logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
+    return SequenceGeneratorAdapter(model, logits_processor, sampler)
+
+
 @regex.register(LlamaCpp)
 def regex_llamacpp(
     model: LlamaCpp,
```
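
The `regex` entry point dispatches on the model type, so registering `MLXLM` is all that is needed to route `generate.regex(model, ...)` to the new adapter path. A self-contained sketch of the dispatch pattern (illustrative stand-ins, not the outlines source):

```python
from functools import singledispatch

@singledispatch
def regex(model, regex_str: str):
    raise NotImplementedError(f"No regex implementation for {type(model)}")

class MLXLM:
    """Stand-in for outlines.models.mlxlm.MLXLM."""

@regex.register(MLXLM)
def regex_mlxlm(model: MLXLM, regex_str: str):
    # In outlines this would return a SequenceGeneratorAdapter; here we
    # only show that the call routed on the model type.
    return f"mlxlm generator for pattern {regex_str!r}"

print(regex(MLXLM(), "[0-9]+"))  # mlxlm generator for pattern '[0-9]+'
```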

outlines/generate/text.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -2,7 +2,7 @@

 from outlines.fsm.guide import StopAtEOSGuide
 from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
-from outlines.models import VLLM, LlamaCpp, OpenAI
+from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI
 from outlines.samplers import Sampler, multinomial


@@ -36,6 +36,11 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
     return generator


+@text.register(MLXLM)
+def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()):
+    return SequenceGeneratorAdapter(model, None, sampler)
+
+
 @text.register(VLLM)
 def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
     return SequenceGeneratorAdapter(model, None, sampler)
```
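
`text_mlxlm` mirrors `text_vllm`: the model is wrapped in a `SequenceGeneratorAdapter` with no logits processor. A usage sketch, assuming the adapter's `__call__` accepts `max_tokens` as it does for the other integrations:

```python
from outlines import models, generate

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
generator = generate.text(model)

# Unconstrained generation, capped at 64 new tokens.
print(generator("Write a haiku about Apple Silicon.", max_tokens=64))
```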

outlines/models/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -10,6 +10,7 @@
 from .exllamav2 import ExLlamaV2Model, exl2
 from .llamacpp import LlamaCpp, llamacpp
 from .mamba import Mamba, mamba
+from .mlxlm import MLXLM, mlxlm
 from .openai import OpenAI, azure_openai, openai
 from .transformers import Transformers, transformers
 from .vllm import VLLM, vllm
```

outlines/models/mlxlm.py

Lines changed: 240 additions & 0 deletions
```python
import dataclasses
from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
    import mlx.core as mx
    import mlx.nn as nn
    from transformers import PreTrainedTokenizer

    from outlines.generate.api import GenerationParameters, SamplingParameters
    from outlines.processors import BaseLogitsProcessor


class MLXLM:
    """
    Represents an `mlx_lm` model
    """

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "PreTrainedTokenizer",
    ):
        self.model = model
        self.mlx_tokenizer = tokenizer  # returns mlx tensors, used for encode()
        self.tokenizer = TransformerTokenizer(
            tokenizer._tokenizer
        )  # _tokenizer is HF Tokenizer

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> str:
        streamer = self.stream(
            prompts, generation_parameters, logits_processor, sampling_parameters
        )
        return "".join(list(streamer))

    def stream(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> Iterator[str]:
        """Generate text using `mlx_lm`.

        Arguments
        ---------
        prompts
            A prompt or list of prompts.
        generation_parameters
            An instance of `GenerationParameters` that contains the prompt,
            the maximum number of tokens, stop sequences and seed. All the
            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
        logits_processor
            The logits processor to use when generating text.
        sampling_parameters
            An instance of `SamplingParameters`, a dataclass that contains
            the name of the sampler to use and related parameters as available
            in Outlines.

        Returns
        -------
        The generated text.
        """
        import mlx.core as mx

        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
            sampling_parameters
        )
        if max_tokens is None:
            max_tokens = int(1e9)

        if not isinstance(prompts, str):
            raise NotImplementedError(
                "The `mlx-lm` library does not support batch inference."
            )
        if sampler == "beam_search":
            raise NotImplementedError(
                "The `mlx-lm` library does not support Beam Search."
            )
        if num_samples != 1:
            raise NotImplementedError(
                "The `mlx-lm` library does not allow taking several samples."
            )
        if top_k is not None:
            raise NotImplementedError("The `mlx-lm` library does not support top_k.")
        if seed is not None:
            raise NotImplementedError("The `mlx-lm` library does not support seed.")
        if stop_at is not None:
            raise NotImplementedError("The `mlx-lm` library does not support stop_at.")

        generate_kwargs = {
            "temp": temperature,
            "top_p": top_p,
            "sampler": sampler,
            "logits_processor": logits_processor,
        }

        # Adapted from
        # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
        prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts))

        for (token, prob), n in zip(
            self.generate_step(prompt_tokens, **generate_kwargs),
            range(max_tokens),
        ):
            if token == self.tokenizer.eos_token_id:
                break
            yield self.tokenizer.decode([token])[0]

    def generate_step(
        self,
        prompt: "mx.array",
        temp: Optional[float],
        top_p: Optional[float],
        sampler: str,
        logits_processor: "BaseLogitsProcessor",
    ) -> Generator[Tuple[int, float], None, None]:
        """
        Adapted from
        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129

        A generator producing token ids based on the given prompt from the model.

        Args:
            prompt (mx.array): The input prompt.
            temp (float): The temperature for sampling, if 0 the argmax is used.
                Default: ``0``.
            top_p (float, optional): Nucleus sampling; a higher value means the
                model considers more less-likely words.
            sampler (str): The sampler string defined by SequenceGeneratorAdapter
            logits_processor (BaseLogitsProcessor): Augment logits before sampling.
        """
        import mlx.core as mx
        import mlx_lm

        temperature: float = temp or 1.0

        def sample(logits: "mx.array") -> Tuple["mx.array", float]:
            softmax_logits = mx.softmax(logits)

            if temperature == 0.0 or sampler == "greedy":
                token = mx.argmax(logits, axis=-1)
            elif sampler == "multinomial":
                if top_p is not None and top_p > 0 and top_p < 1.0:
                    token = mlx_lm.sample_utils.top_p_sampling(
                        logits, top_p, temperature
                    )
                else:
                    token = mx.random.categorical(logits * (1 / temperature))
            else:
                raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`")

            prob = softmax_logits[0, token]
            return token, prob

        kv_heads = (
            [self.model.n_kv_heads] * len(self.model.layers)
            if isinstance(self.model.n_kv_heads, int)
            else self.model.n_kv_heads
        )
        cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]

        # kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model()
        unprocessed_input_ids = prompt
        generated_ids: List[int] = []

        while True:
            logits = self.model(unprocessed_input_ids[None], cache=cache)
            logits = logits[:, -1, :]

            if logits_processor is not None:
                # convert to logits_processor 1d expectation, apply, then convert back
                logits_1d = logits.reshape(-1)
                logits_1d = logits_processor(generated_ids, logits_1d)
                logits = logits_1d.reshape(1, -1)

            new_token_single, prob = sample(logits)
            new_token = new_token_single.item()
            yield new_token, prob

            generated_ids.append(new_token)
            unprocessed_input_ids = new_token_single


def mlxlm(
    model_name: str,
    tokenizer_config: dict = {},
    model_config: dict = {},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
):
    """Instantiate a model from the `mlx_lm` library and its tokenizer.

    Signature adapted from
    https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422

    Parameters
    ----------
    model_name (str): The path or the huggingface repository to load the model from.
    tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
        Defaults to an empty dictionary.
    model_config (dict, optional): Configuration parameters specifically for the model.
        Defaults to an empty dictionary.
    adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
        to the model. Default: ``None``.
    lazy (bool): If False, eval the model parameters to make sure they are
        loaded in memory before returning, otherwise they will be loaded
        when needed. Default: ``False``

    Returns
    -------
    A `MLXLM` model instance.

    """
    try:
        import mlx.core as mx
        import mlx_lm
    except ImportError:
        raise ImportError(
            "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
        )
    if not mx.metal.is_available():
        raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal)")

    model, tokenizer = mlx_lm.load(
        model_name,
        tokenizer_config=tokenizer_config,
        model_config=model_config,
        adapter_path=adapter_path,
        lazy=lazy,
    )
    return MLXLM(model, tokenizer)
```
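
For illustration, a hedged sketch of driving `MLXLM.generate` directly rather than through `outlines.generate`. The field order of the two parameter dataclasses is inferred from the `dataclasses.astuple` unpacking in `stream` above; in normal use, `SequenceGeneratorAdapter` constructs them for you:

```python
from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models import mlxlm

model = mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

# Field order inferred from the astuple() unpacking in MLXLM.stream:
#   GenerationParameters: (max_tokens, stop_at, seed)
#   SamplingParameters:   (sampler, num_samples, top_p, top_k, temperature)
gen_params = GenerationParameters(32, None, None)
samp_params = SamplingParameters("multinomial", 1, None, None, 0.7)

print(model.generate("Hello,", gen_params, None, samp_params))
```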

outlines/processors/__init__.py

Lines changed: 7 additions & 0 deletions
```python
from .structured import (
    BaseLogitsProcessor,
    CFGLogitsProcessor,
    FSMLogitsProcessor,
    JSONLogitsProcessor,
    RegexLogitsProcessor,
)
```
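
This new package re-exports the logits processors used by the dispatchers above. As a sketch, the wiring `regex_mlxlm` performs can be reproduced by hand, using the `RegexLogitsProcessor` signature shown in the `regex.py` hunk:

```python
from outlines import models
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.processors import RegexLogitsProcessor
from outlines.samplers import multinomial

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

# Same wiring regex_mlxlm performs in outlines/generate/regex.py.
processor = RegexLogitsProcessor("[0-9]{4}", tokenizer=model.tokenizer)
generator = SequenceGeneratorAdapter(model, processor, multinomial())
print(generator("Pick a four-digit PIN: "))
```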
