Skip to content

Commit

Permalink
java stream with own data (#158)
Browse files Browse the repository at this point in the history
* updated imports on chat completions template

* Added java streaming with your own data sample

* updated test

---------

Co-authored-by: Chris Schraer <[email protected]>
  • Loading branch information
chschrae and Chris Schraer authored Feb 7, 2024
1 parent a98813c commit 6ef7020
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
<#@ parameter type="System.String" name="ClassName" #>
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.*;
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.core.credential.AzureKeyCredential;

import java.util.ArrayList;
Expand Down
15 changes: 15 additions & 0 deletions src/ai/.x/templates/openai-chat-streaming-with-data-java/_.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"_LongName": "OpenAI Chat Completions (w/ Data + AI Search)",
"_ShortName": "openai-chat-streaming-with-data",
"_Language": "Java",
"ClassName": "OpenAIChatCompletionsWithDataStreamingClass",
"AZURE_OPENAI_API_VERSION": "<insert your OpenAI API version here>",
"AZURE_OPENAI_ENDPOINT": "<insert your OpenAI endpoint here>",
"AZURE_OPENAI_KEY": "<insert your OpenAI API key here>",
"AZURE_OPENAI_CHAT_DEPLOYMENT": "<insert your OpenAI chat deployment name here>",
"AZURE_OPENAI_EMBEDDING_DEPLOYMENT": "<insert your OpenAI embeddings deployment name here>",
"AZURE_OPENAI_SYSTEM_PROMPT": "You are a helpful AI assistant.",
"AZURE_AI_SEARCH_ENDPOINT": "<insert your search endpoint here>",
"AZURE_AI_SEARCH_KEY": "<insert your search api key here>",
"AZURE_AI_SEARCH_INDEX_NAME": "<insert your search index name here>"
}
39 changes: 39 additions & 0 deletions src/ai/.x/templates/openai-chat-streaming-with-data-java/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<groupId>com.azure.ai.openai.samples</groupId>
<artifactId>openai-chat-java-streaming</artifactId>
<version>1.0-SNAPSHOT</version>

<!-- Only runtime dependency: the Azure OpenAI client SDK (preview). -->
<dependencies>
<!-- https://mvnrepository.com/artifact/com.azure/azure-ai-openai -->
<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-ai-openai</artifactId>
<version>1.0.0-beta.6</version>
</dependency>
</dependencies>

<!-- Copies all dependency jars into target/lib during prepare-package so the
     sample can be compiled/run with a plain "-cp target/lib/*" classpath
     instead of requiring Maven at run time. -->
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<version>3.1.2</version>
<executions>
<execution>
<id>copy-dependencies</id>
<phase>prepare-package</phase>
<goals>
<goal>copy-dependencies</goal>
</goals>
<configuration>
<outputDirectory>${project.build.directory}/lib</outputDirectory>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mvn clean package
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
javac -cp target/lib/* src/OpenAIChatCompletionsWithDataStreamingClass.java src/Main.java -d out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
java -cp out;target/lib/* Main
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
<#@ template hostspecific="true" #>
<#@ output extension=".java" encoding="utf-8" #>
<#@ parameter type="System.String" name="ClassName" #>
<#@ parameter type="System.String" name="AZURE_OPENAI_API_VERSION" #>
<#@ parameter type="System.String" name="AZURE_OPENAI_ENDPOINT" #>
<#@ parameter type="System.String" name="AZURE_OPENAI_KEY" #>
<#@ parameter type="System.String" name="AZURE_OPENAI_CHAT_DEPLOYMENT" #>
<#@ parameter type="System.String" name="AZURE_OPENAI_EMBEDDING_DEPLOYMENT" #>
<#@ parameter type="System.String" name="AZURE_OPENAI_SYSTEM_PROMPT" #>
<#@ parameter type="System.String" name="AZURE_AI_SEARCH_ENDPOINT" #>
<#@ parameter type="System.String" name="AZURE_AI_SEARCH_KEY" #>
<#@ parameter type="System.String" name="AZURE_AI_SEARCH_INDEX_NAME" #>
import java.util.Scanner;
import reactor.core.publisher.Flux;
import com.azure.ai.openai.models.ChatCompletions;

// Console REPL entry point for the streaming "chat with your own data" sample.
// NOTE: this is a T4 template; "<#= ... #>" expressions are replaced with the
// values from _.json when the sample is generated.
public class Main {

public static void main(String[] args) {
// Every setting can be overridden via environment variable; otherwise the
// template-generated default (placeholder) string is used.
String openAIKey = (System.getenv("AZURE_OPENAI_KEY") != null) ? System.getenv("AZURE_OPENAI_KEY") : "<insert your OpenAI API key here>";
String openAIEndpoint = (System.getenv("AZURE_OPENAI_ENDPOINT") != null) ? System.getenv("AZURE_OPENAI_ENDPOINT") : "<insert your OpenAI endpoint here>";
String openAIChatDeployment = (System.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT") != null) ? System.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT") : "<insert your OpenAI chat deployment name here>";
String openAISystemPrompt = (System.getenv("AZURE_OPENAI_SYSTEM_PROMPT") != null) ? System.getenv("AZURE_OPENAI_SYSTEM_PROMPT") : "You are a helpful AI assistant.";

// Azure AI Search ("on your data") settings used by the chat helper class.
String openAIApiVersion = System.getenv("AZURE_OPENAI_API_VERSION") != null ? System.getenv("AZURE_OPENAI_API_VERSION") : "<#= AZURE_OPENAI_API_VERSION #>";
String azureSearchEmbeddingsDeploymentName = System.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT") != null ? System.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT") : "<#= AZURE_OPENAI_EMBEDDING_DEPLOYMENT #>";
String azureSearchEndpoint = System.getenv("AZURE_AI_SEARCH_ENDPOINT") != null ? System.getenv("AZURE_AI_SEARCH_ENDPOINT") : "<#= AZURE_AI_SEARCH_ENDPOINT #>";
String azureSearchAPIKey = System.getenv("AZURE_AI_SEARCH_KEY") != null ? System.getenv("AZURE_AI_SEARCH_KEY") : "<#= AZURE_AI_SEARCH_KEY #>";
String azureSearchIndexName = System.getenv("AZURE_AI_SEARCH_INDEX_NAME") != null ? System.getenv("AZURE_AI_SEARCH_INDEX_NAME") : "<#= AZURE_AI_SEARCH_INDEX_NAME #>";

<#= ClassName #> chat = new <#= ClassName #>(openAIKey, openAIEndpoint, openAIChatDeployment, openAISystemPrompt, azureSearchEndpoint, azureSearchIndexName, azureSearchAPIKey, azureSearchEmbeddingsDeploymentName);

// Read-eval-print loop: empty line or "exit" ends the conversation.
Scanner scanner = new Scanner(System.in);
while (true) {
System.out.print("User: ");
String userPrompt = scanner.nextLine();
if (userPrompt.isEmpty() || "exit".equals(userPrompt))
break;

System.out.print("\nAssistant: ");
// The callback prints each streamed delta as it arrives; blockLast()
// waits for the stream to finish before prompting again.
Flux<ChatCompletions> responseFlux = chat.getChatCompletionsStreamingAsync(userPrompt, update -> {
System.out.print(update.getContent());
});
responseFlux.blockLast();
System.out.println("\n");
}
scanner.close();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
<#@ template hostspecific="true" #>
<#@ output extension=".java" encoding="utf-8" #>
<#@ parameter type="System.String" name="ClassName" #>
import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.AzureCognitiveSearchChatExtensionConfiguration;
import com.azure.ai.openai.models.AzureCognitiveSearchChatExtensionParameters;
import com.azure.ai.openai.models.AzureCognitiveSearchIndexFieldMappingOptions;
import com.azure.ai.openai.models.AzureCognitiveSearchQueryType;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.ChatResponseMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.OnYourDataApiKeyAuthenticationOptions;
import com.azure.ai.openai.models.OnYourDataDeploymentNameVectorizationSource;
import com.azure.core.credential.AzureKeyCredential;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.function.Consumer;
import java.util.List;

// Streaming chat-completions helper that augments answers with an Azure
// Cognitive Search index ("on your data"). This is a T4 template; the class
// name is substituted from _.json at generation time.
public class <#= ClassName #> {

    private OpenAIAsyncClient client;          // async Azure OpenAI client
    private ChatCompletionsOptions options;    // holds conversation history + data source config
    private String openAIChatDeployment;
    private String openAISystemPrompt;

    /**
     * Builds the async client and configures the Azure Cognitive Search
     * extension (vector-simple-hybrid query via the given embeddings
     * deployment), then seeds the conversation with the system prompt.
     */
    public <#= ClassName #> (
        String openAIKey,
        String openAIEndpoint,
        String openAIChatDeployment,
        String openAISystemPrompt,
        String azureSearchEndpoint,
        String azureSearchIndexName,
        String azureSearchAPIKey,
        String azureSearchEmbeddingsDeploymentName) {

        this.openAIChatDeployment = openAIChatDeployment;
        this.openAISystemPrompt = openAISystemPrompt;
        client = new OpenAIClientBuilder()
            .endpoint(openAIEndpoint)
            .credential(new AzureKeyCredential(openAIKey))
            .buildAsyncClient();

        AzureCognitiveSearchChatExtensionConfiguration searchConfiguration =
            new AzureCognitiveSearchChatExtensionConfiguration(
                new AzureCognitiveSearchChatExtensionParameters(azureSearchEndpoint, azureSearchIndexName)
                    .setAuthentication(new OnYourDataApiKeyAuthenticationOptions(azureSearchAPIKey))
                    .setQueryType(AzureCognitiveSearchQueryType.VECTOR_SIMPLE_HYBRID)
                    .setEmbeddingDependency(new OnYourDataDeploymentNameVectorizationSource(azureSearchEmbeddingsDeploymentName))
            );

        List<ChatRequestMessage> chatMessages = new ArrayList<>();
        options = new ChatCompletionsOptions(chatMessages)
            .setDataSources(Arrays.asList(searchConfiguration));
        ClearConversation();
        options.setStream(true);
    }

    /** Resets the conversation history to just the system prompt. */
    public void ClearConversation(){
        List<ChatRequestMessage> chatMessages = options.getMessages();
        chatMessages.clear();
        chatMessages.add(new ChatRequestSystemMessage(this.openAISystemPrompt));
    }

    /**
     * Sends {@code userPrompt} and returns the streaming response Flux.
     * {@code callback} (may be null) is invoked with each non-empty delta.
     * The caller must subscribe (e.g. blockLast()) to drive the stream.
     */
    public Flux<ChatCompletions> getChatCompletionsStreamingAsync(String userPrompt,
        Consumer<ChatResponseMessage> callback) {
        options.getMessages().add(new ChatRequestUserMessage(userPrompt));

        StringBuilder responseContent = new StringBuilder();

        // FIX: the original called response.subscribe(...) AND returned the
        // same cold Flux for the caller to block on, which subscribes the
        // publisher twice (two requests). Using doOnNext keeps a single
        // subscription driven by the caller.
        // FIX: the original appended a ChatRequestAssistantMessage to the
        // history inside the per-chunk handler, adding a partial assistant
        // message for every streamed chunk; it is now added exactly once,
        // when the stream completes.
        return client.getChatCompletionsStream(this.openAIChatDeployment, options)
            .doOnNext(chatResponse -> {
                if (chatResponse.getChoices() == null) {
                    return;
                }
                for (ChatChoice update : chatResponse.getChoices()) {
                    if (update.getDelta() == null || update.getDelta().getContent() == null)
                        continue;
                    String content = update.getDelta().getContent();

                    if (update.getFinishReason() == CompletionsFinishReason.CONTENT_FILTERED) {
                        content = content + "\nWARNING: Content filtered!";
                    } else if (update.getFinishReason() == CompletionsFinishReason.TOKEN_LIMIT_REACHED) {
                        content = content + "\nERROR: Exceeded token limit!";
                    }

                    if (content.isEmpty())
                        continue;

                    if (callback != null) {
                        callback.accept(update.getDelta());
                    }
                    responseContent.append(content);
                }
            })
            .doOnComplete(() ->
                options.getMessages().add(new ChatRequestAssistantMessage(responseContent.toString())));
    }
}
2 changes: 1 addition & 1 deletion tests/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,5 @@
^Helper +Function +Class +Library +helper-functions +C# *\r?$\n
^OpenAI +Chat +Completions +openai-chat +C#, +Go, +Java, +JavaScript, +Python *\r?$\n
^OpenAI +Chat +Completions +\(Streaming\) +openai-chat-streaming +C#, +Go, +Java, +JavaScript, +Python *\r?$\n
^OpenAI +Chat +Completions +\(w/ +Data +\+ +AI +Search\) +openai-chat-streaming-with-data +C#, +Go, +JavaScript, +Python *\r?$\n
^OpenAI +Chat +Completions +\(w/ +Data +\+ +AI +Search\) +openai-chat-streaming-with-data +C#, +Go, +Java, +JavaScript, +Python *\r?$\n
^OpenAI +Chat +Completions +\(w/ +Functions\) +openai-chat-streaming-with-functions +C#, +Go, +JavaScript, +Python *\r?$\n

0 comments on commit 6ef7020

Please sign in to comment.