1
1
import inspect
2
+ import itertools
2
3
import json
4
+ import math
3
5
import re
4
6
import warnings
5
7
from typing import Callable , Optional , Tuple , Type , Union
18
20
NUMBER = rf"({ INTEGER } )(\.[0-9]+)?([eE][+-][0-9]+)?"
19
21
BOOLEAN = r"(true|false)"
20
22
NULL = r"null"
21
- WHITESPACE = r"[ ]?"
23
+ WHITESPACE = r"[\n\t ]*"
24
+ SAFE_WHITESPACE = r"[ ]?"
25
+
22
26
23
27
type_to_regex = {
24
28
"string" : STRING ,
25
- "integer" : INTEGER ,
26
29
"number" : NUMBER ,
27
30
"boolean" : BOOLEAN ,
28
31
"null" : NULL ,
32
+ "integer" : INTEGER ,
29
33
}
30
34
31
35
DATE_TIME = r'"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"'
41
45
}
42
46
43
47
44
- def build_regex_from_schema (schema : str , whitespace_pattern : Optional [str ] = None ):
48
+ def get_subranges (minimum , maximum ):
49
+ """
50
+ Convert a range into a list of subranges which can fit into a pattern
51
+
52
+ E.g. minimum=123, maximum=456 cannot easily be made into a regex pattern
53
+ therefore, (123, 456) is converted to
54
+ [(123, 129), (130, 199), (200, 399), (400, 449), (450, 456)]
55
+ which can be converted in get_subrange_pattern() to
56
+ ["12[3-9]", "(1[3-9][0-9]{1}", "[2-3][0-9]{2}", "4[0-4][0-9]{1}", "45[0-6]"]
57
+ """
58
+ min_str = str (minimum ).zfill (len (str (maximum )))
59
+ max_str = str (maximum )
60
+
61
+ # if only the last digit varies, its a valid subrange
62
+ if min_str [:- 1 ] == max_str [:- 1 ]:
63
+ return [(minimum , maximum )]
64
+
65
+ # calculate the shared prefix between minimum and maximum and left-truncate it for now
66
+ num_shared_prefix = len (
67
+ list (itertools .takewhile (lambda x : x [0 ] == x [1 ], zip (min_str , max_str )))
68
+ )
69
+ shared_min = min_str [num_shared_prefix :]
70
+ shared_max = max_str [num_shared_prefix :]
71
+ prefix = min_str [:num_shared_prefix ]
72
+
73
+ # determine how many trailing digits back are valid [0-9]
74
+ # set first digit which doesn't qualify as the flex
75
+ # then combine: {prefix}{flex}[0-9]{count}
76
+ num_truncate = len (shared_min ) - len (shared_min .rstrip ("0" )) + 1
77
+ child_max = int (prefix + shared_min [:- num_truncate ] + "9" * num_truncate )
78
+ if child_max > maximum :
79
+ child_max = int (prefix + shared_max [0 ] + "0" * len (shared_max [1 :])) - 1
80
+
81
+ if child_max == maximum :
82
+ return [(minimum , child_max )]
83
+ return [(minimum , child_max )] + get_subranges (child_max + 1 , maximum )
84
+
85
+
86
+ def get_subrange_pattern (minimum , maximum ):
87
+ """Convert (200, 399) to ([2-3][0-9]{2})"""
88
+
89
+ max_str = str (maximum )
90
+ min_str = str (minimum ).zfill (len (max_str ))
91
+
92
+ last_range_zero = len (min_str ) - re .search (r"[^0]|$" , min_str [::- 1 ]).start ()
93
+ last_range_nine = len (max_str ) - re .search (r"[^9]|$" , max_str [::- 1 ]).start ()
94
+ full_range_start = max (last_range_zero , last_range_nine )
95
+
96
+ shared_prefix = min_str [: full_range_start - 1 ]
97
+ range_digit_min , range_digit_max = (
98
+ min_str [full_range_start - 1 ],
99
+ max_str [full_range_start - 1 ],
100
+ )
101
+
102
+ pattern = rf"{ shared_prefix } [{ range_digit_min } -{ range_digit_max } ]"
103
+
104
+ num_0_9_chars = len (max_str ) - full_range_start
105
+ if num_0_9_chars :
106
+ pattern += rf"[0-9]{{{ num_0_9_chars } }}"
107
+
108
+ return rf"({ pattern } )"
109
+
110
+
111
+ def get_positive_int_range_pattern (minimum , maximum ):
112
+ assert minimum >= 0
113
+ assert maximum >= 0
114
+
115
+ if minimum == 0 :
116
+ minimum = 1
117
+ explicit_zero = True
118
+ if maximum == 0 :
119
+ maximum = 1
120
+ else :
121
+ explicit_zero = False
122
+
123
+ if maximum == float ("inf" ):
124
+ pseudo_maximum = 10 ** math .ceil (math .log10 (minimum + 1 )) - 1
125
+ pseudo_pattern = "|" .join (
126
+ [
127
+ get_subrange_pattern (sub_min , sub_max )
128
+ for sub_min , sub_max in get_subranges (minimum , pseudo_maximum )
129
+ ]
130
+ )
131
+ pattern = rf"([\d]{{{ len (str (pseudo_maximum ))+ 1 } ,}}|{ pseudo_pattern } )"
132
+ else :
133
+ pattern = "|" .join (
134
+ [
135
+ get_subrange_pattern (sub_min , sub_max )
136
+ for sub_min , sub_max in get_subranges (minimum , maximum )
137
+ ]
138
+ )
139
+
140
+ if explicit_zero :
141
+ pattern = rf"(0|({ pattern } ))"
142
+
143
+ return pattern
144
+
145
+
146
+ def get_int_range_pattern (minimum = None , maximum = None ):
147
+ """
148
+ Create a pattern which matches all integers in range [minimum, maximum] *inclusive*
149
+ """
150
+ if minimum is None :
151
+ minimum = - float ("inf" )
152
+ if maximum is None :
153
+ maximum = float ("inf" )
154
+
155
+ if (minimum , maximum ) == (- float ("inf" ), float ("inf" )):
156
+ return INTEGER
157
+
158
+ assert minimum <= maximum
159
+
160
+ if minimum == maximum == 0 :
161
+ pattern = "0"
162
+ elif minimum < 0 and maximum <= 0 :
163
+ abs_pattern = get_positive_int_range_pattern (max (abs (maximum ), 1 ), abs (minimum ))
164
+ pattern = rf"-({ abs_pattern } )"
165
+ if maximum == 0 :
166
+ pattern = rf"0|({ pattern } )"
167
+ elif minimum < 0 and maximum > 0 :
168
+ minimum_pattern = get_positive_int_range_pattern (1 , abs (minimum ))
169
+ maximum_pattern = get_positive_int_range_pattern (0 , maximum )
170
+ pattern = rf"(-({ minimum_pattern } ))|({ maximum_pattern } )"
171
+ elif minimum >= 0 and maximum >= 0 :
172
+ pattern = get_positive_int_range_pattern (minimum , maximum )
173
+ else :
174
+ raise RuntimeError ("This shouldn't occur, please open an issue" )
175
+
176
+ return rf"({ pattern } )"
177
+
178
+
179
+ def get_safe_int ():
180
+ """10% larger than int64 range"""
181
+ return get_int_range_pattern (minimum = - int (1e19 ), maximum = int (1e19 ))
182
+
183
+
184
+ def build_regex_from_schema (
185
+ schema : str , whitespace_pattern : Optional [str ] = None , safe_subset : bool = True
186
+ ):
45
187
"""Turn a JSON schema into a regex that matches any JSON object that follows
46
188
this schema.
47
189
@@ -60,6 +202,12 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non
60
202
whitespace_pattern
61
203
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
62
204
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
205
+ safe_subset
206
+ Use a subset of json schema which performs better with language models.
207
+ If you want to all the model to generate any json structure, set to False.
208
+ Changes the following:
209
+ - If whitespace_pattern is None, sets whitespace pattern to WHITESPACE (r"[ ]?")
210
+ - If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19]
63
211
64
212
Returns
65
213
-------
@@ -83,7 +231,7 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non
83
231
resolver = registry .resolver ()
84
232
85
233
content = schema .contents
86
- return to_regex (resolver , content , whitespace_pattern )
234
+ return to_regex (resolver , content , whitespace_pattern , safe_subset )
87
235
88
236
89
237
def convert_json_schema_to_str (json_schema : Union [dict , str , Type [BaseModel ]]) -> str :
@@ -173,7 +321,10 @@ def validate_quantifiers(
173
321
174
322
175
323
def to_regex (
176
- resolver : Resolver , instance : dict , whitespace_pattern : Optional [str ] = None
324
+ resolver : Resolver ,
325
+ instance : dict ,
326
+ whitespace_pattern : Optional [str ] = None ,
327
+ safe_subset : bool = True ,
177
328
):
178
329
"""Translate a JSON Schema instance into a regex that validates the schema.
179
330
@@ -196,11 +347,18 @@ def to_regex(
196
347
whitespace_pattern
197
348
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
198
349
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
350
+ safe_subset
351
+ Use a subset of json schema which performs better with language models.
352
+ If you want to all the model to generate any json structure, set to False.
353
+ Changes the following:
354
+ - If whitespace_pattern is None, sets whitespace pattern to WHITESPACE (r"[ ]?")
355
+ - If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19]
356
+
199
357
"""
200
358
201
359
# set whitespace pattern
202
360
if whitespace_pattern is None :
203
- whitespace_pattern = WHITESPACE
361
+ whitespace_pattern = SAFE_WHITESPACE if safe_subset else WHITESPACE
204
362
205
363
if instance == {}:
206
364
# JSON Schema Spec: Empty object means unconstrained, any json type is legal
@@ -213,7 +371,9 @@ def to_regex(
213
371
{"type" : "array" },
214
372
{"type" : "object" },
215
373
]
216
- regexes = [to_regex (resolver , t , whitespace_pattern ) for t in types ]
374
+ regexes = [
375
+ to_regex (resolver , t , whitespace_pattern , safe_subset ) for t in types
376
+ ]
217
377
regexes = [rf"({ r } )" for r in regexes ]
218
378
return rf"{ '|' .join (regexes )} "
219
379
@@ -231,7 +391,7 @@ def to_regex(
231
391
last_required_pos = max ([i for i , value in enumerate (is_required ) if value ])
232
392
for i , (name , value ) in enumerate (properties .items ()):
233
393
subregex = f'{ whitespace_pattern } "{ re .escape (name )} "{ whitespace_pattern } :{ whitespace_pattern } '
234
- subregex += to_regex (resolver , value , whitespace_pattern )
394
+ subregex += to_regex (resolver , value , whitespace_pattern , safe_subset )
235
395
if i < last_required_pos :
236
396
subregex = f"{ subregex } { whitespace_pattern } ,"
237
397
elif i > last_required_pos :
@@ -245,7 +405,7 @@ def to_regex(
245
405
property_subregexes = []
246
406
for i , (name , value ) in enumerate (properties .items ()):
247
407
subregex = f'{ whitespace_pattern } "{ name } "{ whitespace_pattern } :{ whitespace_pattern } '
248
- subregex += to_regex (resolver , value , whitespace_pattern )
408
+ subregex += to_regex (resolver , value , whitespace_pattern , safe_subset )
249
409
property_subregexes .append (subregex )
250
410
possible_patterns = []
251
411
for i in range (len (property_subregexes )):
@@ -266,7 +426,8 @@ def to_regex(
266
426
# given subschemas.
267
427
elif "allOf" in instance :
268
428
subregexes = [
269
- to_regex (resolver , t , whitespace_pattern ) for t in instance ["allOf" ]
429
+ to_regex (resolver , t , whitespace_pattern , safe_subset )
430
+ for t in instance ["allOf" ]
270
431
]
271
432
subregexes_str = [f"{ subregex } " for subregex in subregexes ]
272
433
return rf"({ '' .join (subregexes_str )} )"
@@ -275,15 +436,17 @@ def to_regex(
275
436
# any (one or more) of the given subschemas.
276
437
elif "anyOf" in instance :
277
438
subregexes = [
278
- to_regex (resolver , t , whitespace_pattern ) for t in instance ["anyOf" ]
439
+ to_regex (resolver , t , whitespace_pattern , safe_subset )
440
+ for t in instance ["anyOf" ]
279
441
]
280
442
return rf"({ '|' .join (subregexes )} )"
281
443
282
444
# To validate against oneOf, the given data must be valid against exactly
283
445
# one of the given subschemas.
284
446
elif "oneOf" in instance :
285
447
subregexes = [
286
- to_regex (resolver , t , whitespace_pattern ) for t in instance ["oneOf" ]
448
+ to_regex (resolver , t , whitespace_pattern , safe_subset )
449
+ for t in instance ["oneOf" ]
287
450
]
288
451
289
452
xor_patterns = [f"(?:{ subregex } )" for subregex in subregexes ]
@@ -293,7 +456,8 @@ def to_regex(
293
456
# Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx
294
457
elif "prefixItems" in instance :
295
458
element_patterns = [
296
- to_regex (resolver , t , whitespace_pattern ) for t in instance ["prefixItems" ]
459
+ to_regex (resolver , t , whitespace_pattern , safe_subset )
460
+ for t in instance ["prefixItems" ]
297
461
]
298
462
comma_split_pattern = rf"{ whitespace_pattern } ,{ whitespace_pattern } "
299
463
tuple_inner = comma_split_pattern .join (element_patterns )
@@ -321,7 +485,7 @@ def to_regex(
321
485
elif "$ref" in instance :
322
486
path = f"{ instance ['$ref' ]} "
323
487
instance = resolver .lookup (path ).contents
324
- return to_regex (resolver , instance , whitespace_pattern )
488
+ return to_regex (resolver , instance , whitespace_pattern , safe_subset )
325
489
326
490
# The type keyword may either be a string or an array:
327
491
# - If it's a string, it is the name of one of the basic types.
@@ -366,6 +530,8 @@ def to_regex(
366
530
return type_to_regex ["string" ]
367
531
368
532
elif instance_type == "number" :
533
+ # TODO: implement actualy json schema spec parameters: "maximum" and "minimum",
534
+ # should be easy through extending get_int_range_pattern
369
535
bounds = {
370
536
"minDigitsInteger" ,
371
537
"maxDigitsInteger" ,
@@ -405,12 +571,20 @@ def to_regex(
405
571
return type_to_regex ["number" ]
406
572
407
573
elif instance_type == "integer" :
408
- if "minDigits" in instance or "maxDigits" in instance :
409
- min_digits , max_digits = validate_quantifiers (
410
- instance .get ("minDigits" ), instance .get ("maxDigits" ), start_offset = 1
574
+ # TODO: Remove errors eventulaly - these keys aren't part of json schema spec
575
+ if "maxDigits" in instance :
576
+ raise ValueError (
577
+ "'maxDigits' is not supported. Please use 'minimum' instead."
578
+ )
579
+ if "minDigits" in instance :
580
+ raise ValueError (
581
+ "'minDigits' is not supported. Please use 'minimum' instead."
411
582
)
412
- return rf"(-)?(0|[1-9][0-9]{{{ min_digits } ,{ max_digits } }})"
413
- return type_to_regex ["integer" ]
583
+
584
+ maximum = instance .get ("maximum" , int (1e19 ) if safe_subset else None )
585
+ minimum = instance .get ("minimum" , - int (1e19 ) if safe_subset else None )
586
+
587
+ return get_int_range_pattern (minimum , maximum )
414
588
415
589
elif instance_type == "array" :
416
590
num_repeats = _get_num_items_pattern (
0 commit comments