Skip to content

Commit 31a53b1

Browse files
committed
Structured output
Structured output, either low-level by providing a JSON Schema (as a string), or higher-level by providing a struct that “looks like” we want the output to be. ```matlab >> chat = openAIChat; >> prototype = struct(commonName = "dog", scientificName = "Canis familiaris") prototype = struct with fields: commonName: "dog" scientificName: "Canis familiaris" >> generate(chat,"Which animal produces honey?",... ResponseFormat = prototype) ans = struct with fields: commonName: "Honeybee" scientificName: "Apis mellifera" ```
1 parent 178d995 commit 31a53b1

24 files changed

+1096
-397
lines changed

+llms/+azure/validateResponseFormat.m

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
function validateResponseFormat(format,model,messages)
2+
%validateResponseFormat - validate requested response format is available for selected API Version
3+
4+
% Copyright 2024 The MathWorks, Inc.
5+
6+
if ischar(format) | iscellstr(format) %#ok<ISCLSTR>
7+
format = string(format);
8+
end
9+
10+
if isstring(format) && isequal(lower(format),"json")
11+
if nargin > 2
12+
% OpenAI requires that the prompt or message describing the format must contain the word `"json"` or `"JSON"`.
13+
if ~any(cellfun(@(s) contains(s.content,"json","IgnoreCase",true), messages))
14+
error("llms:warningJsonInstruction", ...
15+
llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction"))
16+
end
17+
end
18+
end
19+
20+
if requestsStructuredOutput(format)
21+
% the beauty of ISO-8601: comparing dates by string comparison
22+
if model.APIVersion < "2024-08-01"
23+
error("llms:structuredOutputRequiresAPI", ...
24+
llms.utils.errorMessageCatalog.getMessage("llms:structuredOutputRequiresAPI", model.APIVersion));
25+
end
26+
end
27+
end
28+
29+
function tf = requestsStructuredOutput(format)
30+
% If the response format is not "text" or "json", then the input is interpreted as structured output.
31+
tf = ~isequal(format, "text") & ~isequal(format, "json");
32+
end

+llms/+internal/callAzureChatAPI.m

+12
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,18 @@
104104
parameters.tool_choice = nvp.ToolChoice;
105105
end
106106

107+
if strcmp(nvp.ResponseFormat,"json")
108+
parameters.response_format = struct('type','json_object');
109+
elseif isstruct(nvp.ResponseFormat)
110+
parameters.response_format = struct('type','json_schema',...
111+
'json_schema', struct('strict', true, 'name', 'computedFromPrototype', ...
112+
'schema', llms.internal.jsonSchemaFromPrototype(nvp.ResponseFormat)));
113+
elseif startsWith(string(nvp.ResponseFormat), asManyOfPattern(whitespacePattern)+"{")
114+
parameters.response_format = struct('type','json_schema',...
115+
'json_schema', struct('strict', true, 'name', 'providedInCall', ...
116+
'schema', llms.internal.verbatimJSON(nvp.ResponseFormat)));
117+
end
118+
107119
if ~isempty(nvp.Seed)
108120
parameters.seed = nvp.Seed;
109121
end

+llms/+internal/callOpenAIChatAPI.m

+8
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,14 @@
116116

117117
if strcmp(nvp.ResponseFormat,"json")
118118
parameters.response_format = struct('type','json_object');
119+
elseif isstruct(nvp.ResponseFormat)
120+
parameters.response_format = struct('type','json_schema',...
121+
'json_schema', struct('strict', true, 'name', 'computedFromPrototype', ...
122+
'schema', llms.internal.jsonSchemaFromPrototype(nvp.ResponseFormat)));
123+
elseif startsWith(string(nvp.ResponseFormat), asManyOfPattern(whitespacePattern)+"{")
124+
parameters.response_format = struct('type','json_schema',...
125+
'json_schema', struct('strict', true, 'name', 'providedInCall', ...
126+
'schema', llms.internal.verbatimJSON(nvp.ResponseFormat)));
119127
end
120128

121129
if ~isempty(nvp.Seed)
+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
function schema = jsonSchemaFromPrototype(prototype)
2+
% This function is undocumented and will change in a future release
3+
4+
%jsonSchemaFromPrototype Create a JSON Schema matching given prototype
5+
6+
% Copyright 2024 The MathWorks Inc.
7+
8+
if ~isstruct(prototype)
9+
error("llms:incorrectResponseFormat", ...
10+
llms.utils.errorMessageCatalog.getMessage("llms:incorrectResponseFormat"));
11+
end
12+
13+
% OpenAI requires top-level to be "type":"object"
14+
if ~isscalar(prototype)
15+
prototype = struct("result",{prototype});
16+
end
17+
18+
schema = recursiveSchemaFromPrototype(prototype);
19+
end
20+
21+
function schema = recursiveSchemaFromPrototype(prototype)
22+
if ~isscalar(prototype)
23+
schema = struct("type","array","items",recursiveSchemaFromPrototype(prototype(1)));
24+
elseif isstruct(prototype)
25+
schema = schemaFromStruct(prototype);
26+
elseif isstring(prototype) || iscellstr(prototype)
27+
schema = struct("type","string");
28+
elseif isinteger(prototype)
29+
schema = struct("type","integer");
30+
elseif isnumeric(prototype)
31+
schema = struct("type","number");
32+
elseif islogical(prototype)
33+
schema = struct("type","boolean");
34+
elseif iscategorical(prototype)
35+
schema = struct("type","string", ...
36+
"enum",{categories(prototype)});
37+
elseif ismissing(prototype)
38+
schema = struct("type","null");
39+
else
40+
error("llms:unsupportedDatatypeInPrototype", ...
41+
llms.utils.errorMessageCatalog.getMessage("llms:unsupportedDatatypeInPrototype", class(prototype)));
42+
end
43+
end
44+
45+
function schema = schemaFromStruct(prototype)
46+
fields = string(fieldnames(prototype));
47+
48+
properties = struct();
49+
for fn=fields(:).'
50+
properties.(fn) = recursiveSchemaFromPrototype(prototype.(fn));
51+
end
52+
53+
% to make jsonencode encode an array
54+
if isscalar(fields)
55+
fields = {{fields}};
56+
end
57+
58+
schema = struct( ...
59+
"type","object", ...
60+
"properties",properties, ...
61+
"required",fields, ...
62+
"additionalProperties",false);
63+
end

+llms/+internal/reformatOutput.m

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
function result = reformatOutput(result,responseFormat)
2+
% This function is undocumented and will change in a future release
3+
4+
%reformatOutput - Create the expected struct for structured output
5+
6+
% Copyright 2024 The MathWorks, Inc.
7+
8+
if isstruct(responseFormat)
9+
try
10+
result = jsondecode(result);
11+
catch
12+
error("llms:apiReturnedIncompleteJSON",llms.utils.errorMessageCatalog.getMessage("llms:apiReturnedIncompleteJSON",result))
13+
end
14+
end
15+
if isstruct(responseFormat) && ~isscalar(responseFormat)
16+
result = result.result;
17+
end
18+
if isstruct(responseFormat)
19+
result = llms.internal.useSameFieldTypes(result,responseFormat);
20+
end
21+
end

+llms/+internal/sendRequest.m

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@
4242
response = send(request, matlab.net.URI(endpoint),httpOpts,consumer);
4343
streamedText = consumer.ResponseText;
4444
end
45-
end
45+
end

+llms/+internal/useSameFieldTypes.m

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
function data = useSameFieldTypes(data,prototype)
2+
% This function is undocumented and will change in a future release
3+
4+
%useSameFieldTypes Change struct field data types to match prototype
5+
6+
% Copyright 2024 The MathWorks Inc.
7+
8+
if ~isscalar(data)
9+
data = arrayfun( ...
10+
@(d) llms.internal.useSameFieldTypes(d,prototype), data, ...
11+
UniformOutput=false);
12+
data = vertcat(data{:});
13+
return
14+
end
15+
16+
data = alignTypes(data, prototype);
17+
end
18+
19+
function data = alignTypes(data, prototype)
20+
switch class(prototype)
21+
case "struct"
22+
prototype = prototype(1);
23+
if isscalar(data)
24+
if isequal(fieldnames(data),fieldnames(prototype))
25+
for field_c = fieldnames(data).'
26+
field = field_c{1};
27+
data.(field) = alignTypes(data.(field),prototype.(field));
28+
end
29+
end
30+
else
31+
data = arrayfun(@(d) alignTypes(d,prototype), data, UniformOutput=false);
32+
data = vertcat(data{:});
33+
end
34+
case "string"
35+
data = string(data);
36+
case "categorical"
37+
data = categorical(string(data),categories(prototype));
38+
case "missing"
39+
data = missing;
40+
otherwise
41+
data = cast(data,"like",prototype);
42+
end
43+
end

+llms/+internal/verbatimJSON.m

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
classdef verbatimJSON
2+
% This class is undocumented and will change in a future release
3+
4+
% Copyright 2024 The MathWorks, Inc.
5+
properties
6+
Value (1,1) string
7+
end
8+
methods
9+
function obj = verbatimJSON(str)
10+
obj.Value = str;
11+
end
12+
function json = jsonencode(obj,varargin)
13+
json = obj.Value;
14+
end
15+
end
16+
end
+23-3
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,36 @@
1-
function validateResponseFormat(format,model)
2-
%validateResponseFormat - validate requested response format is available for selected model
1+
function validateResponseFormat(format,model,messages)
2+
%validateResponseFormat - Validate requested response format is available for selected model
33
% Not all OpenAI models support JSON output
44

55
% Copyright 2024 The MathWorks, Inc.
66

7-
if format == "json"
7+
if ischar(format) | iscellstr(format) %#ok<ISCLSTR>
8+
format = string(format);
9+
end
10+
11+
if isequal(format, "json")
812
if ismember(model,["gpt-4","gpt-4-0613","o1-preview","o1-mini"])
913
error("llms:invalidOptionAndValueForModel", ...
1014
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", model));
15+
elseif nargin > 2
16+
% OpenAI says you need to mention JSON somewhere in the input
17+
if ~any(cellfun(@(s) contains(s.content,"json","IgnoreCase",true), messages))
18+
error("llms:warningJsonInstruction", ...
19+
llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction"))
20+
end
1121
else
1222
warning("llms:warningJsonInstruction", ...
1323
llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction"))
1424
end
25+
elseif requestsStructuredOutput(format)
26+
if ~startsWith(model,"gpt-4o")
27+
error("llms:noStructuredOutputForModel", ...
28+
llms.utils.errorMessageCatalog.getMessage("llms:noStructuredOutputForModel", model));
29+
end
1530
end
1631
end
32+
33+
function tf = requestsStructuredOutput(format)
34+
% If the response format is not "text" or "json", then the input is interpreted as structured output.
35+
tf = ~isequal(format, "text") & ~isequal(format, "json");
36+
end

+llms/+utils/errorMessageCatalog.m

+6
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,19 @@
5151
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, ToolChoice must not be specified.";
5252
catalog("llms:mustBeMessagesOrTxt") = "Message must be nonempty string, character array, cell array of character vectors, or messageHistory object.";
5353
catalog("llms:invalidOptionAndValueForModel") = "'{1}' with value '{2}' is not supported for model ""{3}"".";
54+
catalog("llms:noStructuredOutputForModel") = "Structured output is not supported for model ""{1}"".";
55+
catalog("llms:noStructuredOutputForAzureDeployment") = "Structured output is not supported for deployment ""{1}"".";
56+
catalog("llms:structuredOutputRequiresAPI") = "Structured output is not supported for API version ""{1}"". Use APIVersion=""2024-08-01-preview"" or newer.";
5457
catalog("llms:invalidOptionForModel") = "Invalid argument name {1} for model ""{2}"".";
5558
catalog("llms:invalidContentTypeForModel") = "{1} is not supported for model ""{2}"".";
5659
catalog("llms:functionNotAvailableForModel") = "Image editing is not supported for model ""{1}"".";
5760
catalog("llms:promptLimitCharacter") = "Prompt must contain at most {1} characters for model ""{2}"".";
5861
catalog("llms:pngExpected") = "Image must be a PNG file (*.png).";
5962
catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message.";
6063
catalog("llms:apiReturnedError") = "Server returned error indicating: ""{1}""";
64+
catalog("llms:apiReturnedIncompleteJSON") = "Generated output is not valid JSON: ""{1}""";
6165
catalog("llms:dimensionsMustBeSmallerThan") = "Dimensions must be less than or equal to {1}.";
6266
catalog("llms:stream:responseStreamer:InvalidInput") = "Input does not have the expected json format, got ""{1}"".";
67+
catalog("llms:unsupportedDatatypeInPrototype") = "Invalid data type ''{1}'' in prototype. Prototype must be a struct, composed of numerical, string, logical, categorical, or struct.";
68+
catalog("llms:incorrectResponseFormat") = "Invalid response format. Response format must be ""text"", ""json"", a struct, or a string with a JSON Schema definition.";
6369
end

+llms/+utils/mustBeResponseFormat.m

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
function mustBeResponseFormat(format)
2+
% This function is undocumented and will change in a future release
3+
4+
% Copyright 2024 The MathWorks Inc.
5+
if isstring(format) || ischar(format) || iscellstr(format)
6+
mustBeTextScalar(format);
7+
if ~ismember(format,["text","json"]) && ...
8+
~startsWith(format,asManyOfPattern(whitespacePattern)+"{")
9+
error("llms:incorrectResponseFormat", ...
10+
llms.utils.errorMessageCatalog.getMessage("llms:incorrectResponseFormat"));
11+
end
12+
elseif ~isstruct(format)
13+
error("llms:incorrectResponseFormat", ...
14+
llms.utils.errorMessageCatalog.getMessage("llms:incorrectResponseFormat"));
15+
end
16+
end

+llms/jsonSchemaFromPrototype.m

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
function str = jsonSchemaFromPrototype(prototype)
2+
%jsonSchemaFromPrototype - create JSON Schema from prototype
3+
% STR = llms.jsonSchemaFromPrototype(PROTOTYPE) creates a JSON Schema
4+
% that can be used with openAIChat ResponseFormat.
5+
%
6+
% Example:
7+
% >> prototype = struct("name","Alena Zlatkov","age",32);
8+
% >> schema = llms.jsonSchemaFromPrototype(prototype);
9+
% >> generate(openAIChat, "Generate a random person", ResponseFormat=schema)
10+
%
11+
% ans = "{"name":"Emily Carter","age":29}"
12+
13+
% Copyright 2024 The MathWorks, Inc.
14+
15+
str = string(jsonencode(llms.internal.jsonSchemaFromPrototype(prototype),PrettyPrint=true));
16+
end

0 commit comments

Comments
 (0)