
Feature/streamable adapter (#114)
multipletwigs authored May 11, 2024
1 parent 5b0b8db commit bc38e24
Showing 3 changed files with 91 additions and 40 deletions.
@@ -1,7 +1,7 @@
import { AdapterRequest } from "./adapter";

export interface StreamableAdapter {
generateStreamableResponse(args: AdapterRequest): Promise<string>;
generateStreamableResponse(args: AdapterRequest): AsyncGenerator<string>;
}

/**
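The signature change above turns a streamable adapter into an async generator: callers iterate chunks with for await...of instead of awaiting a single string. Below is a minimal sketch of a conforming implementation; the EchoAdapter name and the relative import paths are illustrative assumptions, not part of this commit.

import { AdapterRequest } from "./adapter";
import { StreamableAdapter } from "./BaseAdapter";

// Illustrative only: a trivial adapter that satisfies the updated interface by
// yielding the instruction back one word at a time.
class EchoAdapter implements StreamableAdapter {
  async *generateStreamableResponse(args: AdapterRequest): AsyncGenerator<string> {
    for (const word of args.instruction.split(" ")) {
      yield word + " ";
    }
  }
}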
@@ -28,4 +28,36 @@ describe("GPT-4 Adapter", () => {
// Assert the message content to contain the word akeru
expect(result.toLocaleLowerCase()).toContain("akeru");
});

test("Returns streamable GPT-4 chat completions response", async () => {
// Arrange
const messages = [
{
role: "user" as Role,
content: "hello, who are you?",
},
];
const assistant_instructions =
"You're an AI assistant. Your job is to help the user. Always respond with the word akeru.";

const gpt4Adapter = AdapterManager.instance.getStreamableAdapter("gpt-4");

if (!gpt4Adapter) {
throw new Error("GPT-4 adapter not found");
}

const result = gpt4Adapter.generateStreamableResponse({
message_content: messages,
instruction: assistant_instructions,
});

const responses = [];
for await (const response of result) {
responses.push(response);
}

// Assert the message content to contain the word akeru
const generatedResponse = responses.join("");
expect(generatedResponse.toLocaleLowerCase()).toContain("akeru");
});
});
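The test above drains the generator into an array before asserting. In application code the same generator can be consumed incrementally; a hypothetical usage sketch (inside an async context, reusing the adapter lookup from the test) could look like this:

// Hypothetical usage, not part of this commit: forward chunks as they arrive
// instead of buffering the whole completion.
const adapter = AdapterManager.instance.getStreamableAdapter("gpt-4");
if (!adapter) throw new Error("GPT-4 adapter not found");

for await (const chunk of adapter.generateStreamableResponse({
  message_content: [{ role: "user" as Role, content: "hello, who are you?" }],
  instruction: "You're an AI assistant.",
})) {
  process.stdout.write(chunk); // or pipe each chunk into an HTTP/SSE response
}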
@@ -1,5 +1,8 @@
import { Message } from "@/core/domain/messages";
import { BaseAdapter, StreamableAdapter } from "@/infrastructure/adapters/BaseAdapter";
import {
BaseAdapter,
StreamableAdapter,
} from "@/infrastructure/adapters/BaseAdapter";
import { AdapterRequest } from "../adapter";
import { GPTModels } from "./models";

@@ -17,35 +20,6 @@ interface ChatResponseChoice {
export interface OpenAIResponse {
choices: ChatResponseChoice[];
}

// export async function gpt4Adapter(
// messages: Pick<Message, "role" | "content">[],
// assistant_instructions: string
// ): Promise<unknown> {
// // System will always be the assistant_instruction that created the assistant
// const gpt_messages = [
// { role: "system", content: assistant_instructions },
// ].concat(messages);
// try {
// const res = await fetch("https://api.openai.com/v1/chat/completions", {
// method: "POST",
// headers: {
// "Content-Type": "application/json",
// Authorization: `Bearer ${process.env.OPENAI_API_KEY}`,
// },
// body: JSON.stringify({
// model: "gpt-4",
// messages: gpt_messages,
// }),
// });

// const data: OpenAIResponse = await res.json();
// return data;
// } catch (error) {
// return new Response("Error", { status: 500 });
// }
// }

export class GPTAdapter extends BaseAdapter implements StreamableAdapter {
adapterName: string;
adapterDescription = "This adapter supports all adapter models from OpenAI";
@@ -57,7 +31,9 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter {
}

async generateSingleResponse(args: AdapterRequest): Promise<string> {
const gpt_messages = [{ role: "system", content: args.instruction }].concat(args.message_content);
const gpt_messages = [{ role: "system", content: args.instruction }].concat(
args.message_content
);
try {
const res = await fetch(this.OPENAI_ENDPOINT, {
method: "POST",
@@ -72,15 +48,43 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter {
});

const data: OpenAIResponse = await res.json();
const finished_inference = data.choices[0].message.content
const finished_inference = data.choices[0].message.content;
return Promise.resolve(finished_inference);
} catch (error) {
return Promise.reject(error);
}
}

async generateStreamableResponse(args: AdapterRequest): Promise<string> {
const gpt_messages = [{ role: "system", content: args.instruction }].concat(args.message_content);
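// Buffers raw stream chunks and yields complete "data: ..." SSE lines, stopping at the "data: [DONE]" sentinel.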
async *chunksToLines(chunksAsync: any) {
let previous = "";
for await (const chunk of chunksAsync) {
const bufferChunk = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk);
previous += bufferChunk;
let eolIndex;
while ((eolIndex = previous.indexOf("\n")) >= 0) {
// line includes the EOL
const line = previous.slice(0, eolIndex + 1).trimEnd();
if (line === "data: [DONE]") break;
if (line.startsWith("data: ")) yield line;
previous = previous.slice(eolIndex + 1);
}
}
}

async *linesToMessages(linesAsync: any) {
for await (const line of linesAsync) {
// Strip the SSE "data: " prefix and yield the raw JSON payload
const message = line.substring("data: ".length);

yield message;
}
}

async *generateStreamableResponse(
args: AdapterRequest
): AsyncGenerator<string> {
const gpt_messages = [{ role: "system", content: args.instruction }].concat(
args.message_content
);
try {
const res = await fetch(this.OPENAI_ENDPOINT, {
method: "POST",
@@ -91,17 +95,32 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter {
body: JSON.stringify({
model: this.adapterName,
messages: gpt_messages,
stream: true
stream: true,
}),
});

const data: OpenAIResponse = await res.json();
const finished_inference = data.choices[0].message.content
return Promise.resolve(finished_inference);
// This section of the code is taken from
// https://github.com/openai/openai-node/issues/18
const reader = res.body?.getReader();
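// Read the response body incrementally; each decoded chunk may contain several SSE lines.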
while (true && reader) {
const { done, value } = await reader.read();
if (done) break;

const data = new TextDecoder().decode(value);
const chunkToLines = this.chunksToLines(data);
const linesToMessages = this.linesToMessages(chunkToLines);

for await (const message of linesToMessages) {
const messageObject: any = JSON.parse(message);

// messageObject is a single streamed chunk from the OpenAI chat completions API
const messageToYield = messageObject.choices[0].delta.content;
if (messageToYield) yield messageToYield;
}
}
} catch (error) {
return Promise.reject(error);
}

}

getAdapterInformation(): Object {
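For reference, the chunksToLines and linesToMessages helpers above parse OpenAI's streaming (server-sent events) payload. The raw stream looks roughly like the sketch below (illustrative, fields abbreviated): each data: line carries one JSON chunk whose choices[0].delta.content is the text yielded to the caller, and data: [DONE] marks the end of the stream.

data: {"id":"chatcmpl-xyz","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Ake"}}]}

data: {"id":"chatcmpl-xyz","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"ru"}}]}

data: [DONE]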

