Skip to content

Commit 2f877cf

Browse files
committed
Support enums with different types
1 parent d2f1c9c commit 2f877cf

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

outlines/text/json_schema.py

+22-15
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def to_regex(resolver: Resolver, instance: dict):
6363
- Handle types defined as a list
6464
- Handle constraints on numbers
6565
- Handle special patterns: `date`, `uri`, etc.
66+
- Handle optional fields (not in `required`)
67+
68+
This does not support recursive definitions.
6669
6770
Parameters
6871
----------
@@ -116,12 +119,14 @@ def to_regex(resolver: Resolver, instance: dict):
116119
# The enum keyword is used to restrict a value to a fixed set of values. It
117120
# must be an array with at least one element, where each element is unique.
118121
elif "enum" in instance:
119-
if instance["type"] == "string":
120-
choices = [f'"{re.escape(choice)}"' for choice in instance["enum"]]
121-
return f"({'|'.join(choices)})"
122-
else:
123-
choices = [re.escape(str(choice)) for choice in instance["enum"]]
124-
return f"({'|'.join(choices)})"
122+
choices = []
123+
for choice in instance["enum"]:
124+
if type(choice) in [int, float, bool, None]:
125+
choices.append(re.escape(str(choice)))
126+
elif type(choice) == str:
127+
choices.append(f'"{re.escape(choice)}"')
128+
129+
return f"({'|'.join(choices)})"
125130

126131
elif "$ref" in instance:
127132
path = f"{instance['$ref']}"
@@ -134,8 +139,8 @@ def to_regex(resolver: Resolver, instance: dict):
134139
# the name of one of the basic types, and each element is unique. In this
135140
# case, the JSON snippet is valid if it matches any of the given types.
136141
elif "type" in instance:
137-
type = instance["type"]
138-
if type == "string":
142+
instance_type = instance["type"]
143+
if instance_type == "string":
139144
if "maxLength" in instance or "minLength" in instance:
140145
max_length = instance.get("maxLength", "")
141146
min_length = instance.get("minLength", "")
@@ -156,13 +161,13 @@ def to_regex(resolver: Resolver, instance: dict):
156161
else:
157162
return type_to_regex["string"]
158163

159-
elif type == "number":
164+
elif instance_type == "number":
160165
return type_to_regex["number"]
161166

162-
elif type == "integer":
167+
elif instance_type == "integer":
163168
return type_to_regex["integer"]
164169

165-
elif type == "array":
170+
elif instance_type == "array":
166171
if "items" in instance:
167172
items_regex = to_regex(resolver, instance["items"])
168173
return rf"\[({items_regex})(,({items_regex}))*\]"
@@ -180,17 +185,19 @@ def to_regex(resolver: Resolver, instance: dict):
180185
regexes = [to_regex(resolver, t) for t in types]
181186
return rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)}))*\]"
182187

183-
elif type == "boolean":
188+
elif instance_type == "boolean":
184189
return type_to_regex["boolean"]
185190

186-
elif type == "null":
191+
elif instance_type == "null":
187192
return type_to_regex["null"]
188193

189-
elif isinstance(type, list):
194+
elif isinstance(instance_type, list):
190195
# Here we need to make the choice to exclude generating an object
191196
# if the specification of the object is not given, even though a JSON
192197
# object that contains an object here would be valid under the specification.
193-
regexes = [to_regex(resolver, {"type": t}) for t in type if t != "object"]
198+
regexes = [
199+
to_regex(resolver, {"type": t}) for t in instance_type if t != "object"
200+
]
194201
return rf"({'|'.join(regexes)})"
195202

196203
raise NotImplementedError(

0 commit comments

Comments
 (0)