Skip to content

Commit d30b33c

Browse files
committed
add safe_subset argument to json_schema.to_regex, implement integer minimum / maximum
1 parent 72377db commit d30b33c

File tree

2 files changed

+309
-37
lines changed

2 files changed

+309
-37
lines changed

outlines/fsm/json_schema.py

Lines changed: 193 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import inspect
2+
import itertools
23
import json
4+
import math
35
import re
46
import warnings
57
from typing import Callable, Optional, Tuple, Type, Union
@@ -18,14 +20,16 @@
1820
NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
1921
BOOLEAN = r"(true|false)"
2022
NULL = r"null"
21-
WHITESPACE = r"[ ]?"
23+
WHITESPACE = r"[\n\t ]*"
24+
SAFE_WHITESPACE = r"[ ]?"
25+
2226

2327
type_to_regex = {
2428
"string": STRING,
25-
"integer": INTEGER,
2629
"number": NUMBER,
2730
"boolean": BOOLEAN,
2831
"null": NULL,
32+
"integer": INTEGER,
2933
}
3034

3135
DATE_TIME = r'"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"'
@@ -41,7 +45,145 @@
4145
}
4246

4347

44-
def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None):
48+
def get_subranges(minimum, maximum):
    """
    Split the inclusive range [minimum, maximum] into consecutive subranges,
    each of which can be written as a single digit-class pattern.

    E.g. minimum=123, maximum=456 cannot be expressed as one simple pattern,
    therefore (123, 456) is split into
        [(123, 129), (130, 199), (200, 399), (400, 449), (450, 456)]
    which get_subrange_pattern() can turn into
        ["12[3-9]", "1[3-9][0-9]{1}", "[2-3][0-9]{2}", "4[0-4][0-9]{1}", "45[0-6]"]

    Returns a list of (low, high) tuples ordered from lowest to highest.
    """
    upper_digits = str(maximum)
    pieces = []
    low = minimum

    while True:
        # pad the lower bound so both bounds can be compared digit by digit
        low_digits = str(low).zfill(len(upper_digits))

        # only the final digit differs: the remainder is a single subrange
        if low_digits[:-1] == upper_digits[:-1]:
            pieces.append((low, maximum))
            return pieces

        # length of the leading run of digits shared by both bounds
        common = 0
        for low_char, high_char in zip(low_digits, upper_digits):
            if low_char != high_char:
                break
            common += 1
        head = low_digits[:common]
        low_tail = low_digits[common:]
        high_tail = upper_digits[common:]

        # trailing zeros of the lower bound (plus the digit before them) may
        # run up to 9 while staying a single {head}{flex}[0-9]{count} pattern
        free_digits = len(low_tail) - len(low_tail.rstrip("0")) + 1
        high = int(head + low_tail[:-free_digits] + "9" * free_digits)
        if high > maximum:
            # overshoot: stop just below the round number that opens the
            # maximum's first differing digit
            high = int(head + high_tail[0] + "0" * len(high_tail[1:])) - 1

        pieces.append((low, high))
        if high == maximum:
            return pieces
        low = high + 1
84+
85+
86+
def get_subrange_pattern(minimum, maximum):
    """Convert a single subrange into a regex group, e.g. (200, 399) -> ([2-3][0-9]{2})

    The pair is expected to be one of the subranges produced by
    get_subranges(): a shared prefix, one digit that varies between the two
    bounds, then digits that run over the full 0-9 range.

    `minimum` is zero-padded to the width of `maximum`, so the resulting
    pattern always matches exactly len(str(maximum)) digits.
    """
    max_str = str(maximum)
    min_str = str(minimum).zfill(len(max_str))

    # 1-based position of the last digit that is not part of the trailing
    # run of zeros (minimum) / nines (maximum)
    last_range_zero = len(min_str.rstrip("0"))
    last_range_nine = len(max_str.rstrip("9"))

    # Clamp to 1: when minimum is all zeros and maximum is all nines both
    # positions are 0, and indexing with `- 1` would wrap around to the last
    # digit and emit one digit too many (e.g. (0, 9) -> "[0-9][0-9]{1}").
    full_range_start = max(last_range_zero, last_range_nine, 1)

    flex_index = full_range_start - 1
    shared_prefix = min_str[:flex_index]
    range_digit_min = min_str[flex_index]
    range_digit_max = max_str[flex_index]

    pattern = rf"{shared_prefix}[{range_digit_min}-{range_digit_max}]"

    # every digit after the flexible one is unconstrained
    num_0_9_chars = len(max_str) - full_range_start
    if num_0_9_chars:
        pattern += rf"[0-9]{{{num_0_9_chars}}}"

    return rf"({pattern})"
109+
110+
111+
def get_positive_int_range_pattern(minimum, maximum):
    """
    Create a pattern matching the non-negative integers in [minimum, maximum]
    *inclusive*, written without leading zeros.

    `maximum` may be float("inf") for a range with no upper bound.
    Both bounds must be >= 0 (checked with assertions).

    The range is split with get_subranges() and each piece is rendered with
    get_subrange_pattern(); the pieces are joined with "|".
    """
    assert minimum >= 0
    assert maximum >= 0

    # The range containing only zero has no positive part. Previously both
    # bounds were bumped to 1, which made the pattern also accept "1".
    if minimum == 0 and maximum == 0:
        return "(0)"

    # Zero is emitted as an explicit alternative: the subrange helpers
    # zero-pad the lower bound, so a range starting at 0 would otherwise
    # admit leading zeros such as "007".
    explicit_zero = minimum == 0
    if explicit_zero:
        minimum = 1

    if maximum == float("inf"):
        # Cover [minimum, 99...9] (same digit count as minimum) precisely,
        # then accept any number with more digits. The digit count is taken
        # from the decimal string rather than math.log10, which goes through
        # a float and can land below `minimum` for large integers.
        num_digits = len(str(minimum))
        pseudo_maximum = 10**num_digits - 1
        pseudo_pattern = "|".join(
            get_subrange_pattern(sub_min, sub_max)
            for sub_min, sub_max in get_subranges(minimum, pseudo_maximum)
        )
        # Longer numbers must start with a non-zero digit: JSON does not
        # allow leading zeros, so a bare digit run would accept "0000".
        pattern = rf"([1-9][0-9]{{{num_digits},}}|{pseudo_pattern})"
    else:
        pattern = "|".join(
            get_subrange_pattern(sub_min, sub_max)
            for sub_min, sub_max in get_subranges(minimum, maximum)
        )

    if explicit_zero:
        pattern = rf"(0|({pattern}))"

    return pattern
144+
145+
146+
def get_int_range_pattern(minimum=None, maximum=None):
    """
    Build a regex matching every integer in [minimum, maximum] *inclusive*.

    A bound of None means "unbounded" on that side. With both bounds
    unbounded the generic INTEGER pattern is returned as-is; otherwise the
    result is wrapped in a single group.
    """
    low = -float("inf") if minimum is None else minimum
    high = float("inf") if maximum is None else maximum

    if (low, high) == (-float("inf"), float("inf")):
        return INTEGER

    assert low <= high

    # the range is exactly zero
    if low == high == 0:
        return rf"({'0'})"

    # entirely non-positive: negate the matching range of magnitudes
    if low < 0 and high <= 0:
        magnitudes = get_positive_int_range_pattern(max(abs(high), 1), abs(low))
        negatives = rf"-({magnitudes})"
        if high == 0:
            # zero carries no sign, so it is a separate alternative
            negatives = rf"0|({negatives})"
        return rf"({negatives})"

    # straddles zero: negative side joined with the non-negative side
    if low < 0 and high > 0:
        below_zero = get_positive_int_range_pattern(1, abs(low))
        from_zero = get_positive_int_range_pattern(0, high)
        return rf"({rf'(-({below_zero}))|({from_zero})'})"

    # entirely non-negative
    if low >= 0 and high >= 0:
        return rf"({get_positive_int_range_pattern(low, high)})"

    raise RuntimeError("This shouldn't occur, please open an issue")
177+
178+
179+
def get_safe_int():
    """Pattern for integers in [-1e19, 1e19], a range slightly wider than int64"""
    bound = int(1e19)
    return get_int_range_pattern(minimum=-bound, maximum=bound)
182+
183+
184+
def build_regex_from_schema(
185+
schema: str, whitespace_pattern: Optional[str] = None, safe_subset: bool = True
186+
):
45187
"""Turn a JSON schema into a regex that matches any JSON object that follows
46188
this schema.
47189
@@ -60,6 +202,12 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non
60202
whitespace_pattern
61203
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
62204
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
205+
safe_subset
206+
Use a subset of json schema which performs better with language models.
207+
If you want to allow the model to generate any json structure, set to False.
208+
Changes the following:
209+
- If whitespace_pattern is None, sets whitespace pattern to SAFE_WHITESPACE (r"[ ]?")
210+
- If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19]
63211
64212
Returns
65213
-------
@@ -83,7 +231,7 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non
83231
resolver = registry.resolver()
84232

85233
content = schema.contents
86-
return to_regex(resolver, content, whitespace_pattern)
234+
return to_regex(resolver, content, whitespace_pattern, safe_subset)
87235

88236

89237
def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
@@ -173,7 +321,10 @@ def validate_quantifiers(
173321

174322

175323
def to_regex(
176-
resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None
324+
resolver: Resolver,
325+
instance: dict,
326+
whitespace_pattern: Optional[str] = None,
327+
safe_subset: bool = True,
177328
):
178329
"""Translate a JSON Schema instance into a regex that validates the schema.
179330
@@ -196,11 +347,18 @@ def to_regex(
196347
whitespace_pattern
197348
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
198349
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
350+
safe_subset
351+
Use a subset of json schema which performs better with language models.
352+
If you want to allow the model to generate any json structure, set to False.
353+
Changes the following:
354+
- If whitespace_pattern is None, sets whitespace pattern to SAFE_WHITESPACE (r"[ ]?")
355+
- If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19]
356+
199357
"""
200358

201359
# set whitespace pattern
202360
if whitespace_pattern is None:
203-
whitespace_pattern = WHITESPACE
361+
whitespace_pattern = SAFE_WHITESPACE if safe_subset else WHITESPACE
204362

205363
if instance == {}:
206364
# JSON Schema Spec: Empty object means unconstrained, any json type is legal
@@ -213,7 +371,9 @@ def to_regex(
213371
{"type": "array"},
214372
{"type": "object"},
215373
]
216-
regexes = [to_regex(resolver, t, whitespace_pattern) for t in types]
374+
regexes = [
375+
to_regex(resolver, t, whitespace_pattern, safe_subset) for t in types
376+
]
217377
regexes = [rf"({r})" for r in regexes]
218378
return rf"{'|'.join(regexes)}"
219379

@@ -231,7 +391,7 @@ def to_regex(
231391
last_required_pos = max([i for i, value in enumerate(is_required) if value])
232392
for i, (name, value) in enumerate(properties.items()):
233393
subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}'
234-
subregex += to_regex(resolver, value, whitespace_pattern)
394+
subregex += to_regex(resolver, value, whitespace_pattern, safe_subset)
235395
if i < last_required_pos:
236396
subregex = f"{subregex}{whitespace_pattern},"
237397
elif i > last_required_pos:
@@ -245,7 +405,7 @@ def to_regex(
245405
property_subregexes = []
246406
for i, (name, value) in enumerate(properties.items()):
247407
subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
248-
subregex += to_regex(resolver, value, whitespace_pattern)
408+
subregex += to_regex(resolver, value, whitespace_pattern, safe_subset)
249409
property_subregexes.append(subregex)
250410
possible_patterns = []
251411
for i in range(len(property_subregexes)):
@@ -266,7 +426,8 @@ def to_regex(
266426
# given subschemas.
267427
elif "allOf" in instance:
268428
subregexes = [
269-
to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"]
429+
to_regex(resolver, t, whitespace_pattern, safe_subset)
430+
for t in instance["allOf"]
270431
]
271432
subregexes_str = [f"{subregex}" for subregex in subregexes]
272433
return rf"({''.join(subregexes_str)})"
@@ -275,15 +436,17 @@ def to_regex(
275436
# any (one or more) of the given subschemas.
276437
elif "anyOf" in instance:
277438
subregexes = [
278-
to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"]
439+
to_regex(resolver, t, whitespace_pattern, safe_subset)
440+
for t in instance["anyOf"]
279441
]
280442
return rf"({'|'.join(subregexes)})"
281443

282444
# To validate against oneOf, the given data must be valid against exactly
283445
# one of the given subschemas.
284446
elif "oneOf" in instance:
285447
subregexes = [
286-
to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"]
448+
to_regex(resolver, t, whitespace_pattern, safe_subset)
449+
for t in instance["oneOf"]
287450
]
288451

289452
xor_patterns = [f"(?:{subregex})" for subregex in subregexes]
@@ -293,7 +456,8 @@ def to_regex(
293456
# Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx
294457
elif "prefixItems" in instance:
295458
element_patterns = [
296-
to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"]
459+
to_regex(resolver, t, whitespace_pattern, safe_subset)
460+
for t in instance["prefixItems"]
297461
]
298462
comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}"
299463
tuple_inner = comma_split_pattern.join(element_patterns)
@@ -321,7 +485,7 @@ def to_regex(
321485
elif "$ref" in instance:
322486
path = f"{instance['$ref']}"
323487
instance = resolver.lookup(path).contents
324-
return to_regex(resolver, instance, whitespace_pattern)
488+
return to_regex(resolver, instance, whitespace_pattern, safe_subset)
325489

326490
# The type keyword may either be a string or an array:
327491
# - If it's a string, it is the name of one of the basic types.
@@ -366,6 +530,8 @@ def to_regex(
366530
return type_to_regex["string"]
367531

368532
elif instance_type == "number":
533+
# TODO: implement the actual json schema spec parameters: "maximum" and "minimum",
534+
# should be easy through extending get_int_range_pattern
369535
bounds = {
370536
"minDigitsInteger",
371537
"maxDigitsInteger",
@@ -405,12 +571,20 @@ def to_regex(
405571
return type_to_regex["number"]
406572

407573
elif instance_type == "integer":
408-
if "minDigits" in instance or "maxDigits" in instance:
409-
min_digits, max_digits = validate_quantifiers(
410-
instance.get("minDigits"), instance.get("maxDigits"), start_offset=1
574+
# TODO: Remove errors eventually - these keys aren't part of json schema spec
575+
if "maxDigits" in instance:
576+
raise ValueError(
577+
"'maxDigits' is not supported. Please use 'minimum' instead."
578+
)
579+
if "minDigits" in instance:
580+
raise ValueError(
581+
"'minDigits' is not supported. Please use 'minimum' instead."
411582
)
412-
return rf"(-)?(0|[1-9][0-9]{{{min_digits},{max_digits}}})"
413-
return type_to_regex["integer"]
583+
584+
maximum = instance.get("maximum", int(1e19) if safe_subset else None)
585+
minimum = instance.get("minimum", -int(1e19) if safe_subset else None)
586+
587+
return get_int_range_pattern(minimum, maximum)
414588

415589
elif instance_type == "array":
416590
num_repeats = _get_num_items_pattern(

0 commit comments

Comments
 (0)