
Commit eadb1c3

Merge pull request #30 from lapp0/fix-918-mlx
`outlines.models.mlxlm`
2 parents ed44a47 + e2d8a5c

File tree

13 files changed: +645 −4 lines changed

docs/reference/models/mlxlm.md

Lines changed: 32 additions & 0 deletions

```diff
@@ -0,0 +1,32 @@
+# mlx-lm
+
+Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx-examples/tree/main/llms), allowing models to be run quickly on Apple Silicon via the [mlx](https://ml-explore.github.io/mlx/build/html/index.html) library.
+
+## Installation
+
+In addition to `outlines`, you must install the `mlx-lm` and `mlx` libraries. You must use a device which [supports Metal](https://support.apple.com/en-us/102894).
+
+## Using `models.mlxlm`
+
+```python
+from outlines import models
+
+model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
+```
+
+With the loaded model, you can generate text or perform structured generation, e.g.
+
+```python
+from outlines import models, generate
+
+model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
+
+phone_number_pattern = "\\+?[1-9][0-9]{7,14}"
+generator = generate.regex(model, phone_number_pattern)
+
+model_output = generator("What's Jenny's number?\n")
+print(model_output)
+# '8675309'
+```
+
+For more examples, see the [cookbook](cookbook/index.md).
```
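Since `generate.json` builds on the same regex path this commit registers for `MLXLM`, JSON-schema generation should also work with `models.mlxlm`. A minimal sketch, assuming `generate.json` dispatches through `generate.regex` (the model name, prompt, and printed output are illustrative):

```python
from pydantic import BaseModel

from outlines import models, generate


class Contact(BaseModel):
    name: str
    phone: str


model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
generator = generate.json(model, Contact)

# The regex compiled from the JSON schema constrains every sampled token,
# so the output parses into a Contact instance.
contact = generator("Extract the contact: Jenny, 867-5309\n")
print(contact)  # e.g. Contact(name='Jenny', phone='8675309') -- output varies
```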

mkdocs.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -124,6 +124,7 @@ nav:
 - vLLM: reference/models/vllm.md
 - Llama.cpp: reference/models/llamacpp.md
 - Transformers: reference/models/transformers.md
+- MLX: reference/models/mlxlm.md
 - ExllamaV2: reference/models/exllamav2.md
 - Mamba: reference/models/mamba.md
 - OpenAI: reference/models/openai.md
```

outlines/generate/cfg.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -4,6 +4,7 @@
 from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.models.llamacpp import LlamaCpp
+from outlines.models.mlxlm import MLXLM
 from outlines.models.vllm import VLLM
 from outlines.samplers import Sampler, multinomial

@@ -33,14 +34,15 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera
     return generator


+@cfg.register(MLXLM)
 @cfg.register(VLLM)
-def cfg_vllm(
-    model: VLLM,
+def cfg_unimplemented(
+    model,
     cfg_str: str,
     sampler: Sampler = multinomial(),
 ):
     raise NotImplementedError(
-        "The CFG Logits processor is not available for the vLLM integration."
+        f"The CFG Logits processor is not available for {type(model)}."
     )
```
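Registering `MLXLM` on the same stub as `VLLM` means `generate.cfg` fails fast with a clear error instead of silently producing unguided output. A hedged sketch of the resulting behavior (the model name and grammar string are illustrative):

```python
from outlines import models, generate

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

try:
    # Any Lark grammar string triggers the same NotImplementedError path.
    generate.cfg(model, "start: /[0-9]+/")
except NotImplementedError as e:
    print(e)
    # The CFG Logits processor is not available for <class 'outlines.models.mlxlm.MLXLM'>
```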
outlines/generate/regex.py

Lines changed: 13 additions & 0 deletions

```diff
@@ -4,6 +4,7 @@
 from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.models.llamacpp import LlamaCpp
+from outlines.models.mlxlm import MLXLM
 from outlines.models.vllm import VLLM
 from outlines.samplers import Sampler, multinomial

@@ -37,6 +38,18 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
     return generator


+@regex.register(MLXLM)
+def regex_mlxlm(
+    model: MLXLM,
+    regex_str: str,
+    sampler: Sampler = multinomial(),
+):
+    from outlines.processors import RegexLogitsProcessor
+
+    logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
+    return SequenceGeneratorAdapter(model, logits_processor, sampler)
+
+
 @regex.register(LlamaCpp)
 def regex_llamacpp(
     model: LlamaCpp,
```
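`regex.register` is the hook of Python's `functools.singledispatch`: the generic `regex` function dispatches on the type of its first argument, and each backend registers its own implementation. A self-contained sketch of the same pattern with stand-in types (not Outlines classes):

```python
from functools import singledispatch


class LocalModel:
    pass


class RemoteModel:
    pass


@singledispatch
def describe(model):
    # Fallback used when no specialized implementation is registered.
    return "generic fallback"


@describe.register(LocalModel)
def _(model):
    return "local path"


@describe.register(RemoteModel)
def _(model):
    return "remote path"


print(describe(LocalModel()))   # local path
print(describe(RemoteModel()))  # remote path
print(describe(object()))       # generic fallback
```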

outlines/generate/text.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -2,7 +2,7 @@

 from outlines.fsm.guide import StopAtEOSGuide
 from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
-from outlines.models import VLLM, LlamaCpp, OpenAI
+from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI
 from outlines.samplers import Sampler, multinomial


@@ -36,6 +36,11 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
     return generator


+@text.register(MLXLM)
+def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()):
+    return SequenceGeneratorAdapter(model, None, sampler)
+
+
 @text.register(VLLM)
 def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
     return SequenceGeneratorAdapter(model, None, sampler)
```
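Passing `None` as the logits processor yields plain, unconstrained sampling, and the adapter exposes both one-shot and streaming calls backed by `MLXLM.generate` and `MLXLM.stream`. A hedged usage sketch (model name, prompt, and sampler settings are illustrative):

```python
from outlines import models, generate
from outlines.samplers import multinomial

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
generator = generate.text(model, multinomial(temperature=0.7))

# One-shot generation.
print(generator("Write a haiku about the sea.\n", max_tokens=64))

# Token-by-token streaming, backed by MLXLM.stream().
for token in generator.stream("Write a haiku about the sea.\n", max_tokens=64):
    print(token, end="")
```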

outlines/models/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -10,6 +10,7 @@
 from .exllamav2 import ExLlamaV2Model, exl2
 from .llamacpp import LlamaCpp, llamacpp
 from .mamba import Mamba, mamba
+from .mlxlm import MLXLM, mlxlm
 from .openai import OpenAI, azure_openai, openai
 from .transformers import Transformers, transformers
 from .vllm import VLLM, vllm
```

outlines/models/mlxlm.py

Lines changed: 240 additions & 0 deletions

```diff
@@ -0,0 +1,240 @@
+import dataclasses
+from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union
+
+from .transformers import TransformerTokenizer
+
+if TYPE_CHECKING:
+    import mlx.nn as nn
+    from transformers import PreTrainedTokenizer
+
+    from outlines.generate.api import GenerationParameters, SamplingParameters
+    from outlines.processors import BaseLogitsProcessor
+
+try:
+    import mlx.core as mx
+    import mlx_lm
+except ImportError:
+    pass
+
+
+class MLXLM:
+    """
+    Represents an `mlx_lm` model
+    """
+
+    def __init__(
+        self,
+        model: "nn.Module",
+        tokenizer: "PreTrainedTokenizer",
+    ):
+        self.model = model
+        self.mlx_tokenizer = tokenizer  # returns mlx tensors, used for encode()
+        self.tokenizer = TransformerTokenizer(
+            tokenizer._tokenizer
+        )  # _tokenizer is HF Tokenizer
+
+    def generate(
+        self,
+        prompts: Union[str, List[str]],
+        generation_parameters: "GenerationParameters",
+        logits_processor,
+        sampling_parameters: "SamplingParameters",
+    ) -> str:
+        streamer = self.stream(
+            prompts, generation_parameters, logits_processor, sampling_parameters
+        )
+        return "".join(list(streamer))
+
+    def stream(
+        self,
+        prompts: Union[str, List[str]],
+        generation_parameters: "GenerationParameters",
+        logits_processor,
+        sampling_parameters: "SamplingParameters",
+    ) -> Iterator[str]:
+        """Generate text using `mlx_lm`.
+
+        Arguments
+        ---------
+        prompts
+            A prompt or list of prompts.
+        generation_parameters
+            An instance of `GenerationParameters` that contains the prompt,
+            the maximum number of tokens, stop sequences and seed. All the
+            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
+        logits_processor
+            The logits processor to use when generating text.
+        sampling_parameters
+            An instance of `SamplingParameters`, a dataclass that contains
+            the name of the sampler to use and related parameters as available
+            in Outlines.
+        Returns
+        -------
+        The generated text.
+        """
+        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
+        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
+            sampling_parameters
+        )
+        if max_tokens is None:
+            max_tokens = int(1e9)
+
+        if not isinstance(prompts, str):
+            raise NotImplementedError(
+                "The `mlx-lm` library does not support batch inference."
+            )
+        if sampler == "beam_search":
+            raise NotImplementedError(
+                "The `mlx-lm` library does not support Beam Search."
+            )
+        if num_samples != 1:
+            raise NotImplementedError(
+                "The `mlx-lm` library does not support multiple samples."
+            )
+        if top_k is not None:
+            raise NotImplementedError("The `mlx-lm` library does not support top_k.")
+        if seed is not None:
+            raise NotImplementedError("The `mlx-lm` library does not support seed.")
+        if stop_at is not None:
+            raise NotImplementedError("The `mlx-lm` library does not support stop_at.")
+
+        generate_kwargs = {
+            "temp": temperature,
+            "top_p": top_p,
+            "sampler": sampler,
+            "logits_processor": logits_processor,
+        }
+
+        # Adapted from
+        # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
+        prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts))
+
+        for (token, prob), n in zip(
+            self.generate_step(prompt_tokens, **generate_kwargs),
+            range(max_tokens),
+        ):
+            if token == self.tokenizer.eos_token_id:
+                break
+            yield self.tokenizer.decode([token])[0]
+
+    def generate_step(
+        self,
+        prompt: "mx.array",
+        temp: Optional[float],
+        top_p: Optional[float],
+        sampler: str,
+        logits_processor: "BaseLogitsProcessor",
+    ) -> Generator[Tuple[int, float], None, None]:
+        """
+        Adapted from
+        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129
+
+        A generator producing token ids based on the given prompt from the model.
+
+        Args:
+            prompt (mx.array): The input prompt.
+            temp (float): The temperature for sampling; if 0 the argmax is used.
+                Default: ``0``.
+            top_p (float, optional): Nucleus sampling; higher values make the
+                model consider more low-probability words.
+            sampler (str): The sampler string defined by SequenceGeneratorAdapter.
+            logits_processor (BaseLogitsProcessor): Augments logits before sampling.
+        """
+        temperature: float = temp or 1.0
+
+        def sample(logits: "mx.array") -> Tuple["mx.array", float]:
+            softmax_logits = mx.softmax(logits)
+
+            if temperature == 0.0 or sampler == "greedy":
+                token = mx.argmax(logits, axis=-1)
+            elif sampler == "multinomial":
+                if top_p is not None and top_p > 0 and top_p < 1.0:
+                    token = mlx_lm.sample_utils.top_p_sampling(
+                        logits, top_p, temperature
+                    )
+                else:
+                    token = mx.random.categorical(logits * (1 / temperature))
+            else:
+                raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`")
+
+            prob = softmax_logits[0, token]
+            return token, prob
+
+        kv_heads = (
+            [self.model.n_kv_heads] * len(self.model.layers)
+            if isinstance(self.model.n_kv_heads, int)
+            else self.model.n_kv_heads
+        )
+        cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]
+
+        # The kv cache contains the processed input IDs; we pass the
+        # unprocessed inputs and the cache to model().
+        unprocessed_input_ids = prompt
+        generated_ids: List[int] = []
+
+        while True:
+            logits = self.model(unprocessed_input_ids[None], cache=cache)
+            logits = logits[:, -1, :]
+
+            if logits_processor is not None:
+                # Convert to the logits processor's 1d expectation, apply,
+                # then convert back.
+                logits_1d = logits.reshape(-1)
+                logits_1d = logits_processor(generated_ids, logits_1d)
+                logits = logits_1d.reshape(1, -1)
+
+            new_token_single, prob = sample(logits)
+            new_token = new_token_single.item()
+            yield new_token, prob
+
+            generated_ids.append(new_token)
+            unprocessed_input_ids = new_token_single
+
+
+def mlxlm(
+    model_name: str,
+    tokenizer_config: dict = {},
+    model_config: dict = {},
+    adapter_path: Optional[str] = None,
+    lazy: bool = False,
+):
+    """Instantiate a model from the `mlx_lm` library and its tokenizer.
+
+    Signature adapted from
+    https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422
+
+    Parameters
+    ----------
+    model_name (str): The path or the Hugging Face repository to load the model from.
+    tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
+        Defaults to an empty dictionary.
+    model_config (dict, optional): Configuration parameters specifically for the model.
+        Defaults to an empty dictionary.
+    adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+        to the model. Default: ``None``.
+    lazy (bool): If False, eval the model parameters to make sure they are
+        loaded in memory before returning, otherwise they will be loaded
+        when needed. Default: ``False``.
+
+    Returns
+    -------
+    A `MLXLM` model instance.
+
+    """
+    try:
+        import mlx.core as mx
+        import mlx_lm
+    except ImportError:
+        raise ImportError(
+            "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
+        )
+    if not mx.metal.is_available():
+        raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal).")
+
+    model, tokenizer = mlx_lm.load(
+        model_name,
+        tokenizer_config=tokenizer_config,
+        model_config=model_config,
+        adapter_path=adapter_path,
+        lazy=lazy,
+    )
+    return MLXLM(model, tokenizer)
```
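The reshape round-trip in `generate_step` exists because Outlines' logits processors expect a flat, vocabulary-sized vector together with the ids generated so far, while the model returns a `(batch, vocab)` array. A minimal sketch of that contract using NumPy and a toy processor (names here are illustrative, not the Outlines API):

```python
import numpy as np


def block_token_zero(generated_ids, logits_1d):
    """Toy 1-D logits processor: forbid token id 0 once generation has started."""
    if len(generated_ids) > 0:
        logits_1d = logits_1d.copy()
        logits_1d[0] = -np.inf
    return logits_1d


logits = np.random.randn(1, 8)     # (batch=1, vocab=8), as model() returns
logits_1d = logits.reshape(-1)     # flatten to the processor's 1-D expectation
logits_1d = block_token_zero([3, 5], logits_1d)
logits = logits_1d.reshape(1, -1)  # back to (1, vocab) for sampling
assert logits[0, 0] == -np.inf     # token 0 can no longer be sampled
```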

outlines/processors/__init__.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -0,0 +1,7 @@
+from .structured import (
+    BaseLogitsProcessor,
+    CFGLogitsProcessor,
+    FSMLogitsProcessor,
+    JSONLogitsProcessor,
+    RegexLogitsProcessor,
+)
```
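These exports let callers construct processors directly, mirroring what `regex_mlxlm` does internally. A hedged sketch (the model name is illustrative; the constructor signature is taken from the `regex_mlxlm` call above):

```python
from outlines import models
from outlines.processors import RegexLogitsProcessor

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

# Constrain generation to a four-digit string, as regex_mlxlm would.
processor = RegexLogitsProcessor("[0-9]{4}", tokenizer=model.tokenizer)
```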
