Skip to content

Commit c5714fa

Browse files
aschwarlouf
authored andcommitted
Support date,time,date-time,uuid formats
1 parent 8a0bafc commit c5714fa

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

outlines/fsm/json_schema.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@
2424
"null": NULL,
2525
}
2626

27+
DATE_TIME = r"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"
28+
DATE = r"(?:\d{4})-(?:0[1-9]|1[0-2])-(?:0[1-9]|[1-2][0-9]|3[0-1])"
29+
TIME = r"(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\\.[0-9]+)?(Z)?"
30+
UUID = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
31+
32+
format_to_regex = {
33+
"uuid": UUID,
34+
"date-time": DATE_TIME,
35+
"date": DATE,
36+
"time": TIME,
37+
}
38+
2739

2840
def build_regex_from_object(object: Union[str, Callable, BaseModel]):
2941
"""Turn a JSON schema into a regex that matches any JSON object that follows
@@ -210,6 +222,20 @@ def to_regex(resolver: Resolver, instance: dict):
210222
return rf'(^"{pattern[1:-1]}"$)'
211223
else:
212224
return rf'("{pattern}")'
225+
elif "format" in instance:
226+
format = instance["format"]
227+
if format == "date-time":
228+
return format_to_regex["date-time"]
229+
elif format == "uuid":
230+
return format_to_regex["uuid"]
231+
elif format == "date":
232+
return format_to_regex["date"]
233+
elif format == "time":
234+
return format_to_regex["time"]
235+
else:
236+
raise NotImplementedError(
237+
f"Format {format} is not supported by Outlines"
238+
)
213239
else:
214240
return type_to_regex["string"]
215241

tests/fsm/test_json_schema.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77

88
from outlines.fsm.json_schema import (
99
BOOLEAN,
10+
DATE,
11+
DATE_TIME,
1012
INTEGER,
1113
NULL,
1214
NUMBER,
1315
STRING,
1416
STRING_INNER,
17+
TIME,
18+
UUID,
1519
build_regex_from_object,
1620
get_schema_from_signature,
1721
to_regex,
@@ -451,3 +455,75 @@ def test_match(schema, regex, examples):
451455
assert match.span() == (0, len(string))
452456
else:
453457
assert match is None
458+
459+
460+
@pytest.mark.parametrize(
461+
"schema,regex,examples",
462+
[
463+
# UUID
464+
(
465+
{"title": "Foo", "type": "string", "format": "uuid"},
466+
UUID,
467+
[
468+
("123e4567-e89b-12d3-a456-426614174000", True),
469+
("123e4567-e89b-12d3-a456-42661417400", False),
470+
("123e4567-e89b-12d3-a456-42661417400g", False),
471+
("123e4567-e89b-12d3-a456-42661417400-", False),
472+
("", False),
473+
],
474+
),
475+
# DATE-TIME
476+
(
477+
{"title": "Foo", "type": "string", "format": "date-time"},
478+
DATE_TIME,
479+
[
480+
("2018-11-13T20:20:39Z", True),
481+
("2016-09-18T17:34:02.666Z", True),
482+
("2008-05-11T15:30:00Z", True),
483+
("2021-01-01T00:00:00", True),
484+
("2022-01-10 07:19:30", False), # missing T
485+
("2022-12-10T10-04-29", False), # incorrect separator
486+
("2023-01-01", False),
487+
],
488+
),
489+
# DATE
490+
(
491+
{"title": "Foo", "type": "string", "format": "date"},
492+
DATE,
493+
[
494+
("2018-11-13", True),
495+
("2016-09-18", True),
496+
("2008-05-11", True),
497+
("2015-13-01", False), # incorrect month
498+
("2022-01", False), # missing day
499+
("2022/12/01", False), # incorrect separator"
500+
],
501+
),
502+
# TIME
503+
(
504+
{"title": "Foo", "type": "string", "format": "time"},
505+
TIME,
506+
[
507+
("20:20:39Z", True),
508+
("15:30:00Z", True),
509+
("25:30:00", False), # incorrect hour
510+
("15:30", False), # missing seconds
511+
("15:30:00.000", False), # missing Z
512+
("15-30-00", False), # incorrect separator
513+
("15:30:00+01:00", False), # incorrect separator
514+
],
515+
),
516+
],
517+
)
518+
def test_format(schema, regex, examples):
519+
schema = json.dumps(schema)
520+
test_regex = build_regex_from_object(schema)
521+
assert test_regex == regex
522+
523+
for string, does_match in examples:
524+
match = re.fullmatch(test_regex, string)
525+
if does_match:
526+
assert match[0] == string
527+
assert match.span() == (0, len(string))
528+
else:
529+
assert match is None

0 commit comments

Comments
 (0)