Skip to content

Commit 3c3b2e7

Browse files
authored
Merge pull request #285 from gs-snagaraj/main
Support OpenAIFunction Custom object Schema
2 parents 63db2ab + b71458a commit 3c3b2e7

File tree

6 files changed

+160
-10
lines changed

6 files changed

+160
-10
lines changed

Diff for: aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIFunction.java

+28-4
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,19 @@
77
import com.fasterxml.jackson.core.JsonProcessingException;
88
import com.fasterxml.jackson.databind.JsonNode;
99
import com.fasterxml.jackson.databind.ObjectMapper;
10+
import com.microsoft.semantickernel.exceptions.SKException;
11+
import com.microsoft.semantickernel.orchestration.responseformat.ResponseSchemaGenerator;
1012
import com.microsoft.semantickernel.semanticfunctions.InputVariable;
1113
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionMetadata;
14+
import org.apache.commons.lang3.StringUtils;
1215
import java.util.ArrayList;
1316
import java.util.Collections;
1417
import java.util.HashMap;
1518
import java.util.List;
1619
import java.util.Locale;
1720
import java.util.Map;
21+
import java.util.Objects;
22+
import java.util.concurrent.ConcurrentHashMap;
1823
import java.util.stream.Collectors;
1924
import javax.annotation.Nonnull;
2025
import javax.annotation.Nullable;
@@ -159,14 +164,17 @@ private static String getSchemaForFunctionParameter(@Nullable InputVariable para
159164
entries.add("\"type\":\"" + type + "\"");
160165

161166
// Add description if present
167+
String description =null;
162168
if (parameter != null && parameter.getDescription() != null && !parameter.getDescription()
163169
.isEmpty()) {
164-
String description = parameter.getDescription();
170+
description = parameter.getDescription();
165171
description = description.replaceAll("\\r?\\n|\\r", "");
166172
description = description.replace("\"", "\\\"");
167-
168-
description = String.format("\"description\":\"%s\"", description);
169-
entries.add(description);
173+
entries.add(String.format("\"description\":\"%s\"", description));
174+
}
175+
// If custom type, generate schema
176+
if("object".equalsIgnoreCase(type)) {
177+
return getObjectSchema(parameter.getType(), description);
170178
}
171179

172180
// Add enum options if parameter is an enum
@@ -219,4 +227,20 @@ private static String getJavaTypeToOpenAiFunctionType(String javaType) {
219227
return "object";
220228
}
221229
}
230+
231+
private static String getObjectSchema(String type, String description){
232+
String schema= "{ \"type\" : \"object\" }";
233+
try {
234+
Class<?> clazz = Class.forName(type);
235+
schema = ResponseSchemaGenerator.jacksonGenerator().generateSchema(clazz);
236+
237+
} catch (ClassNotFoundException | SKException ignored) {
238+
239+
}
240+
Map<String, Object> properties = BinaryData.fromString(schema).toObject(Map.class);
241+
if(StringUtils.isNotBlank(description)) {
242+
properties.put("description", description);
243+
}
244+
return BinaryData.fromObject(properties).toString();
245+
}
222246
}

Diff for: aiservices/openai/src/test/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/JsonSchemaTest.java

+89
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
// Copyright (c) Microsoft. All rights reserved.
22
package com.microsoft.semantickernel.aiservices.openai.chatcompletion;
33

4+
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
45
import com.fasterxml.jackson.core.JsonProcessingException;
56
import com.microsoft.semantickernel.orchestration.responseformat.JsonSchemaResponseFormat;
7+
import com.microsoft.semantickernel.plugin.KernelPlugin;
8+
import com.microsoft.semantickernel.plugin.KernelPluginFactory;
9+
import com.microsoft.semantickernel.semanticfunctions.KernelFunction;
10+
import com.microsoft.semantickernel.semanticfunctions.annotations.DefineKernelFunction;
11+
import com.microsoft.semantickernel.semanticfunctions.annotations.KernelFunctionParameter;
612
import org.junit.jupiter.api.Assertions;
713
import org.junit.jupiter.api.Test;
14+
import reactor.core.publisher.Mono;
815

916
public class JsonSchemaTest {
1017

@@ -24,4 +31,86 @@ public void jacksonGenerationTest() throws JsonProcessingException {
2431
"\"type\":\"object\",\"properties\":{\"bar\":{}}"));
2532
}
2633

34+
@Test
35+
public void openAIFunctionTest() {
36+
KernelPlugin plugin = KernelPluginFactory.createFromObject(
37+
new TestPlugin(),
38+
"test");
39+
40+
Assertions.assertNotNull(plugin);
41+
Assertions.assertEquals(plugin.getName(), "test");
42+
Assertions.assertEquals(plugin.getFunctions().size(), 3);
43+
44+
KernelFunction<?> testFunction = plugin.getFunctions()
45+
.get("asyncPersonFunction");
46+
OpenAIFunction openAIFunction = OpenAIFunction.build(
47+
testFunction.getMetadata(),
48+
plugin.getName());
49+
50+
String parameters = "{\"type\":\"object\",\"required\":[\"person\",\"input\"],\"properties\":{\"input\":{\"type\":\"string\",\"description\":\"input string\"},\"person\":{\"type\":\"object\",\"properties\":{\"age\":{\"type\":\"integer\",\"description\":\"The age of the person.\"},\"name\":{\"type\":\"string\",\"description\":\"The name of the person.\"},\"title\":{\"type\":\"string\",\"enum\":[\"MS\",\"MRS\",\"MR\"],\"description\":\"The title of the person.\"}},\"required\":[\"age\",\"name\",\"title\"],\"additionalProperties\":false,\"description\":\"input person\"}}}";
51+
Assertions.assertEquals(parameters, openAIFunction.getFunctionDefinition().getParameters().toString());
52+
53+
}
54+
55+
56+
public static class TestPlugin {
57+
58+
@DefineKernelFunction
59+
public String testFunction(
60+
@KernelFunctionParameter(name = "input", description = "input string") String input) {
61+
return "test" + input;
62+
}
63+
64+
@DefineKernelFunction(returnType = "int")
65+
public Mono<Integer> asyncTestFunction(
66+
@KernelFunctionParameter(name = "input") String input) {
67+
return Mono.just(1);
68+
}
69+
70+
@DefineKernelFunction(returnType = "int", description = "test function description",
71+
name = "asyncPersonFunction", returnDescription = "test return description")
72+
public Mono<Integer> asyncPersonFunction(
73+
@KernelFunctionParameter(name = "person",description = "input person", type = Person.class) Person person,
74+
@KernelFunctionParameter(name = "input", description = "input string") String input) {
75+
return Mono.just(1);
76+
}
77+
}
78+
79+
private static enum Title {
80+
MS,
81+
MRS,
82+
MR
83+
}
84+
85+
public static class Person {
86+
@JsonPropertyDescription("The name of the person.")
87+
private String name;
88+
@JsonPropertyDescription("The age of the person.")
89+
private int age;
90+
@JsonPropertyDescription("The title of the person.")
91+
private Title title;
92+
93+
94+
public Person(String name, int age) {
95+
this.name = name;
96+
this.age = age;
97+
}
98+
99+
public String getName() {
100+
return name;
101+
}
102+
103+
public int getAge() {
104+
return age;
105+
}
106+
107+
public Title getTitle() {
108+
return title;
109+
}
110+
111+
public void setTitle(Title title) {
112+
this.title = title;
113+
}
114+
}
115+
27116
}

Diff for: samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/App.java

+2-6
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ public static void main(String[] args) throws Exception {
7070
ChatCompletionService.class);
7171

7272
ContextVariableTypes
73-
.addGlobalConverter(ContextVariableTypeConverter.builder(LightModel.class)
74-
.toPromptString(new Gson()::toJson)
75-
.build());
73+
.addGlobalConverter(new LightModelTypeConverter());
7674

7775
KernelHooks hook = new KernelHooks();
7876

@@ -99,9 +97,7 @@ public static void main(String[] args) throws Exception {
9997
InvocationContext invocationContext = new Builder()
10098
.withReturnMode(InvocationReturnMode.LAST_MESSAGE_ONLY)
10199
.withToolCallBehavior(ToolCallBehavior.allowAllKernelFunctions(true))
102-
.withContextVariableConverter(ContextVariableTypeConverter.builder(LightModel.class)
103-
.toPromptString(new Gson()::toJson)
104-
.build())
100+
.withContextVariableConverter(new LightModelTypeConverter())
105101
.build();
106102

107103
// Create a history to store the conversation

Diff for: samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/LightModel.java

+7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
// Copyright (c) Microsoft. All rights reserved.
22
package com.microsoft.semantickernel.samples.demos.lights;
33

4+
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
5+
46
public class LightModel {
57

8+
@JsonPropertyDescription("The unique identifier of the light")
69
private int id;
10+
11+
@JsonPropertyDescription("The name of the light")
712
private String name;
13+
14+
@JsonPropertyDescription("The state of the light")
815
private Boolean isOn;
916

1017
public LightModel(int id, String name, Boolean isOn) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.microsoft.semantickernel.samples.demos.lights;
2+
3+
import com.google.gson.Gson;
4+
import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter;
5+
6+
public class LightModelTypeConverter extends ContextVariableTypeConverter<LightModel> {
7+
private static final Gson gson = new Gson();
8+
9+
public LightModelTypeConverter() {
10+
super(
11+
LightModel.class,
12+
obj -> {
13+
if(obj instanceof String) {
14+
return gson.fromJson((String)obj, LightModel.class);
15+
} else {
16+
return gson.fromJson(gson.toJson(obj), LightModel.class);
17+
}
18+
},
19+
(types, lightModel) -> gson.toJson(lightModel),
20+
json -> gson.fromJson(json, LightModel.class)
21+
);
22+
}
23+
}

Diff for: samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/LightsPlugin.java

+11
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ public List<LightModel> getLights() {
2424
return lights;
2525
}
2626

27+
@DefineKernelFunction(name = "add_light", description = "Adds a new light")
28+
public String addLight(
29+
@KernelFunctionParameter(name = "newLight", description = "new Light Details", type = LightModel.class) LightModel light) {
30+
if( light != null) {
31+
System.out.println("Adding light " + light.getName());
32+
lights.add(light);
33+
return "Light added";
34+
}
35+
return "Light failed to added";
36+
}
37+
2738
@DefineKernelFunction(name = "change_state", description = "Changes the state of the light")
2839
public LightModel changeState(
2940
@KernelFunctionParameter(name = "id", description = "The ID of the light to change", type = int.class) int id,

0 commit comments

Comments
 (0)