From 053efd7e7018a49875e4d3ddfe8e3c7cf843a70d Mon Sep 17 00:00:00 2001 From: Zach Khong Date: Thu, 9 May 2024 20:20:52 +0800 Subject: [PATCH 1/4] clean up for adapters for both streamable adapters and nonstreamable ones --- packages/akeru-server/package.json | 1 + packages/akeru-server/scripts/cleanDB.ts | 2 +- .../assistant/assistantController.test.ts | 2 +- .../application/controllers/userController.ts | 2 +- .../application/services/assistantService.ts | 2 +- .../application/services/messageService.ts | 2 +- .../core/application/services/runService.ts | 46 +++---- .../application/services/threadService.ts | 2 +- .../core/application/services/tokenService.ts | 2 +- .../core/application/services/userService.ts | 2 +- .../akeru-server/src/core/domain/assistant.ts | 4 +- packages/akeru-server/src/core/domain/run.ts | 3 +- .../src/core/domain/validators.ts | 7 -- .../adaptaters/adaptermanager.ts | 69 ----------- .../adaptaters/openai/gpt4Adapter.test.ts | 29 ----- .../adaptaters/openai/gpt4Adapter.ts | 42 ------- .../infrastructure/adapters/AdapterManager.ts | 49 ++++++++ .../infrastructure/adapters/BaseAdapter.ts | 28 +++++ .../src/infrastructure/adapters/adapter.ts | 23 ++++ .../adapters/openai/GPTAdapter.test.ts | 31 +++++ .../adapters/openai/GPTAdapter.ts | 113 ++++++++++++++++++ .../infrastructure/adapters/openai/models.ts | 1 + .../{adaptaters => adapters}/redisAdapter.ts | 0 .../adapters/validators/models.ts | 1 + .../validators/validatorAdapter.test.ts | 4 +- .../validators/validatorAdapter.ts | 4 +- 26 files changed, 282 insertions(+), 189 deletions(-) delete mode 100644 packages/akeru-server/src/core/domain/validators.ts delete mode 100644 packages/akeru-server/src/infrastructure/adaptaters/adaptermanager.ts delete mode 100644 packages/akeru-server/src/infrastructure/adaptaters/openai/gpt4Adapter.test.ts delete mode 100644 packages/akeru-server/src/infrastructure/adaptaters/openai/gpt4Adapter.ts create mode 100644 packages/akeru-server/src/infrastructure/adapters/AdapterManager.ts create mode 100644 packages/akeru-server/src/infrastructure/adapters/BaseAdapter.ts create mode 100644 packages/akeru-server/src/infrastructure/adapters/adapter.ts create mode 100644 packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts create mode 100644 packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts create mode 100644 packages/akeru-server/src/infrastructure/adapters/openai/models.ts rename packages/akeru-server/src/infrastructure/{adaptaters => adapters}/redisAdapter.ts (100%) create mode 100644 packages/akeru-server/src/infrastructure/adapters/validators/models.ts rename packages/akeru-server/src/infrastructure/{adaptaters => adapters}/validators/validatorAdapter.test.ts (88%) rename packages/akeru-server/src/infrastructure/{adaptaters => adapters}/validators/validatorAdapter.ts (93%) diff --git a/packages/akeru-server/package.json b/packages/akeru-server/package.json index 8323dde..5a9a5ba 100644 --- a/packages/akeru-server/package.json +++ b/packages/akeru-server/package.json @@ -8,6 +8,7 @@ }, "dependencies": { "@elysiajs/bearer": "^1.0.2", + "@elysiajs/stream": "^1.0.2", "@elysiajs/swagger": "^1.0.4", "dayjs": "^1.11.10", "elysia": "latest", diff --git a/packages/akeru-server/scripts/cleanDB.ts b/packages/akeru-server/scripts/cleanDB.ts index 52ec317..0636f0e 100644 --- a/packages/akeru-server/scripts/cleanDB.ts +++ b/packages/akeru-server/scripts/cleanDB.ts @@ -1,4 +1,4 @@ -import { redis } from "@/infrastructure/adaptaters/redisAdapter"; +import { 
redis } from "@/infrastructure/adapters/redisAdapter"; const script = async () => { try { diff --git a/packages/akeru-server/src/core/application/controllers/assistant/assistantController.test.ts b/packages/akeru-server/src/core/application/controllers/assistant/assistantController.test.ts index 6d34080..c8f2aea 100644 --- a/packages/akeru-server/src/core/application/controllers/assistant/assistantController.test.ts +++ b/packages/akeru-server/src/core/application/controllers/assistant/assistantController.test.ts @@ -1,6 +1,6 @@ import { createSuperAdminForTesting } from "@/__tests__/utils"; import { app } from "@/index"; -import { redis } from "@/infrastructure/adaptaters/redisAdapter"; +import { redis } from "@/infrastructure/adapters/redisAdapter"; import { test, expect, describe } from "bun:test"; import { UNAUTHORIZED_MISSING_TOKEN } from "../../ports/returnValues"; diff --git a/packages/akeru-server/src/core/application/controllers/userController.ts b/packages/akeru-server/src/core/application/controllers/userController.ts index 88373b9..c8ae55f 100644 --- a/packages/akeru-server/src/core/application/controllers/userController.ts +++ b/packages/akeru-server/src/core/application/controllers/userController.ts @@ -1,6 +1,6 @@ import { assignRole, createUser } from "core/application/services/userService"; import { createToken } from "@/core/application/services/tokenService"; -import { redis } from "@/infrastructure/adaptaters/redisAdapter"; +import { redis } from "@/infrastructure/adapters/redisAdapter"; import { ulid } from "ulid"; /** diff --git a/packages/akeru-server/src/core/application/services/assistantService.ts b/packages/akeru-server/src/core/application/services/assistantService.ts index 3ea6e3d..a6f0969 100644 --- a/packages/akeru-server/src/core/application/services/assistantService.ts +++ b/packages/akeru-server/src/core/application/services/assistantService.ts @@ -1,5 +1,5 @@ import type { Assistant } from "@/core/domain/assistant"; -import { redis } from "@/infrastructure/adaptaters/redisAdapter"; +import { redis } from "@/infrastructure/adapters/redisAdapter"; /** * Creates an assistant in Redis if it does not exist and adds a 'CREATED_BY' relationship to the user. 
diff --git a/packages/akeru-server/src/core/application/services/messageService.ts b/packages/akeru-server/src/core/application/services/messageService.ts index 176d260..b74a724 100644 --- a/packages/akeru-server/src/core/application/services/messageService.ts +++ b/packages/akeru-server/src/core/application/services/messageService.ts @@ -1,8 +1,8 @@ import { v4 as uuidv4 } from "uuid"; -import { redis } from "@/infrastructure/adaptaters/redisAdapter"; import { Message } from "@/core/domain/messages"; import dayjs from "dayjs"; import { getUserRole } from "./userService"; +import { redis } from "@/infrastructure/adapters/redisAdapter"; export async function createMessage( userId: string, diff --git a/packages/akeru-server/src/core/application/services/runService.ts b/packages/akeru-server/src/core/application/services/runService.ts index ef3b057..93913c8 100644 --- a/packages/akeru-server/src/core/application/services/runService.ts +++ b/packages/akeru-server/src/core/application/services/runService.ts @@ -3,15 +3,12 @@ import { ThreadRun, ThreadRunRequest } from "@/core/domain/run"; import { getAssistantData } from "./assistantService"; import { getThread } from "./threadService"; -import { getAllMessage } from "./messageService"; +import { createMessage, getAllMessage } from "./messageService"; import { Role } from "@/core/domain/roles"; +import { AdapterManager } from "@/infrastructure/adapters/AdapterManager"; -import { - AdapterManager, - GPT4AdapterManager, - ValidatorAdapterManager, -} from "@/infrastructure/adaptaters/adaptermanager"; +import { v4 as uuidv4 } from "uuid"; export async function runAssistantWithThread(runData: ThreadRunRequest) { // get all messages from the thread, and run it over to the assistant to get a response @@ -36,33 +33,28 @@ export async function runAssistantWithThread(runData: ThreadRunRequest) { }; }); - let adapterManager: AdapterManager; + const adapter = AdapterManager.instance.getBaseAdapter(assistantData.model); - // Calls the appropriate adapter based on what model the assistant uses - if (assistantData.model === "gpt-4") { - adapterManager = new GPT4AdapterManager(); - } else if (assistantData.model === "llama-2-7b-chat-int8") { - adapterManager = new ValidatorAdapterManager("llama-2-7b-chat-int8"); - } else { - throw new Error("Unsupported assistant model"); + if (!adapter) { + throw new Error("Adapter not found"); } // Generate response using the appropriate adapter - const assistantResponse: string = await adapterManager.generateResponse( - everyRoleAndContent, - assistantData.instruction - ); + const assistantResponse: string = await adapter.generateSingleResponse({ + message_content: everyRoleAndContent, + instruction: assistantData.instruction, + }); // Add assistant response to the thread - await adapterManager.createUserMessage( - assistant_id, - thread_id, - assistantResponse - ); - - // Create thread run response - const threadRunResponse: ThreadRun = - await adapterManager.createThreadRunResponse(assistant_id, thread_id); + await createMessage(assistant_id, thread_id, assistantResponse); + + // Return the thread run response + const threadRunResponse: ThreadRun = { + assistant_id: assistant_id, + thread_id: thread_id, + id: uuidv4(), + created_at: new Date(), + }; return threadRunResponse; } diff --git a/packages/akeru-server/src/core/application/services/threadService.ts b/packages/akeru-server/src/core/application/services/threadService.ts index f4ef5e6..dbe3c2a 100644 --- 
a/packages/akeru-server/src/core/application/services/threadService.ts +++ b/packages/akeru-server/src/core/application/services/threadService.ts @@ -1,5 +1,5 @@ import { Thread } from "@/core/domain/thread"; -import { redis } from "@/infrastructure/adaptaters/redisAdapter"; +import { redis } from "@/infrastructure/adapters/redisAdapter"; /** * Creates a new thread in Redis. diff --git a/packages/akeru-server/src/core/application/services/tokenService.ts b/packages/akeru-server/src/core/application/services/tokenService.ts index 57a017b..7b9e804 100644 --- a/packages/akeru-server/src/core/application/services/tokenService.ts +++ b/packages/akeru-server/src/core/application/services/tokenService.ts @@ -1,4 +1,4 @@ -import { redis } from "@/infrastructure/adaptaters/redisAdapter"; +import { redis } from "@/infrastructure/adapters/redisAdapter"; import { getUserPermissions } from "@/core/application/services/userService"; import jwt from "jsonwebtoken"; import { PermissionDetailArray } from "@/core/domain/permissions"; diff --git a/packages/akeru-server/src/core/application/services/userService.ts b/packages/akeru-server/src/core/application/services/userService.ts index 399628b..8f1247c 100644 --- a/packages/akeru-server/src/core/application/services/userService.ts +++ b/packages/akeru-server/src/core/application/services/userService.ts @@ -2,7 +2,7 @@ import { Role, getRolePermissions } from "@/core/domain/roles"; import { HumanUserBody, User, UserRole } from "@/core/domain/user"; -import { redis } from "@/infrastructure/adaptaters/redisAdapter"; +import { redis } from "@/infrastructure/adapters/redisAdapter"; import { ulid } from "ulid"; /** diff --git a/packages/akeru-server/src/core/domain/assistant.ts b/packages/akeru-server/src/core/domain/assistant.ts index 44452d9..283d001 100644 --- a/packages/akeru-server/src/core/domain/assistant.ts +++ b/packages/akeru-server/src/core/domain/assistant.ts @@ -1,9 +1,9 @@ -import { ModelsType } from "./validators"; +import { SupportedModels } from "@/infrastructure/adapters/adapter"; export type Assistant = { id: string; name: string; - model: ModelsType; + model: SupportedModels; tools: { type: string }[]; fileIds: string[]; instruction: string; diff --git a/packages/akeru-server/src/core/domain/run.ts b/packages/akeru-server/src/core/domain/run.ts index e2fe21e..5bf597a 100644 --- a/packages/akeru-server/src/core/domain/run.ts +++ b/packages/akeru-server/src/core/domain/run.ts @@ -9,6 +9,7 @@ export interface ThreadRun { created_at: Date; assistant_id: Assistant["id"]; thread_id: Thread["id"]; + stream?: boolean } -export type ThreadRunRequest = Pick; +export type ThreadRunRequest = Pick; diff --git a/packages/akeru-server/src/core/domain/validators.ts b/packages/akeru-server/src/core/domain/validators.ts deleted file mode 100644 index adcbf93..0000000 --- a/packages/akeru-server/src/core/domain/validators.ts +++ /dev/null @@ -1,7 +0,0 @@ -// types for all the models we'll have on the subnet -// chaining like this: "BtModels & Models" gives a type of never - -type BtModels = "bt-gpt-4"; -type Models = "gpt-4" | "gpt-3.5" | "llama-2-7b-chat-int8"; - -export type ModelsType = BtModels | Models; diff --git a/packages/akeru-server/src/infrastructure/adaptaters/adaptermanager.ts b/packages/akeru-server/src/infrastructure/adaptaters/adaptermanager.ts deleted file mode 100644 index fefc24e..0000000 --- a/packages/akeru-server/src/infrastructure/adaptaters/adaptermanager.ts +++ /dev/null @@ -1,69 +0,0 @@ -import { ThreadRun } from "@/core/domain/run"; 
-import { v4 as uuidv4 } from "uuid"; -import { gpt4Adapter } from "./openai/gpt4Adapter"; -import { validatorAdapter } from "./validators/validatorAdapter"; - -import { createMessage } from "@/core/application/services/messageService"; -import { ModelsType } from "@/core/domain/validators"; - -export abstract class AdapterManager { - abstract generateResponse( - everyRoleAndContent: { role: string; content: string }[], - instruction: any - ): Promise; - - async createUserMessage( - assistant_id: string, - thread_id: string, - assistantResponse: string - ): Promise { - await createMessage(assistant_id, thread_id, assistantResponse); - } - - async createThreadRunResponse( - assistant_id: string, - thread_id: string - ): Promise { - const threadRunResponse: ThreadRun = { - id: uuidv4(), - assistant_id: assistant_id, - thread_id: thread_id, - created_at: new Date(), - }; - return threadRunResponse; - } -} - -export class GPT4AdapterManager extends AdapterManager { - async generateResponse( - everyRoleAndContent: any, - instruction: any - ): Promise { - const gpt4AdapterRes: any = await gpt4Adapter( - everyRoleAndContent, - instruction - ); - return gpt4AdapterRes.choices[0].message.content; - } -} - -export class ValidatorAdapterManager extends AdapterManager { - private modelName: ModelsType; - - constructor(modelName: ModelsType) { - super(); - this.modelName = modelName; - } - - async generateResponse( - everyRoleAndContent: any, - instruction: any - ): Promise { - const validatorAdapterRes: any = await validatorAdapter( - everyRoleAndContent, - this.modelName, - instruction - ); - return validatorAdapterRes.choices[0].message.content; - } -} diff --git a/packages/akeru-server/src/infrastructure/adaptaters/openai/gpt4Adapter.test.ts b/packages/akeru-server/src/infrastructure/adaptaters/openai/gpt4Adapter.test.ts deleted file mode 100644 index 5f85539..0000000 --- a/packages/akeru-server/src/infrastructure/adaptaters/openai/gpt4Adapter.test.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { test, expect, describe } from "bun:test"; -import { OpenAIResponse, gpt4Adapter } from "./gpt4Adapter"; -import { Role } from "@/core/domain/roles"; - -describe("GPT-4 Adapter", () => { - test("Returns GPT-4 chat completions response", async () => { - // Arrange - const messages = [ - { - role: "user" as Role, - content: "hello, who are you?", - }, - ]; - const assistant_instructions = - "You're an AI assistant. You're job is to help the user. 
Always respond with the word akeru."; - // Act - const result = (await gpt4Adapter( - messages, - assistant_instructions - )) as OpenAIResponse; - - // Assert the message content to contain the word akeru - expect(result.choices[0].message.content.toLocaleLowerCase()).toContain( - "akeru" - ); - }); -}); - - diff --git a/packages/akeru-server/src/infrastructure/adaptaters/openai/gpt4Adapter.ts b/packages/akeru-server/src/infrastructure/adaptaters/openai/gpt4Adapter.ts deleted file mode 100644 index 9cbc0cf..0000000 --- a/packages/akeru-server/src/infrastructure/adaptaters/openai/gpt4Adapter.ts +++ /dev/null @@ -1,42 +0,0 @@ -import { Message } from "@/core/domain/messages"; - -interface ChatMessage { - content: string, - role: "assistant" | "system" | "user" -} - -interface ChatResponseChoice{ - finish_reason: string, - index: number, - message: ChatMessage -} - -export interface OpenAIResponse { - choices: ChatResponseChoice[] -} - -export async function gpt4Adapter( - messages: Pick[], - assistant_instructions: string -): Promise { - // System will always be the assistant_instruction that created the assistant - const gpt_messages = [{role: "system", content: assistant_instructions}].concat(messages) - try { - const res = await fetch("https://api.openai.com/v1/chat/completions", { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${process.env.OPENAI_API_KEY}`, - }, - body: JSON.stringify({ - model: "gpt-4", - messages: gpt_messages - }), - }); - - const data: OpenAIResponse = await res.json(); - return data; - } catch (error) { - return new Response("Error", { status: 500 }); - } -} \ No newline at end of file diff --git a/packages/akeru-server/src/infrastructure/adapters/AdapterManager.ts b/packages/akeru-server/src/infrastructure/adapters/AdapterManager.ts new file mode 100644 index 0000000..f9b89fe --- /dev/null +++ b/packages/akeru-server/src/infrastructure/adapters/AdapterManager.ts @@ -0,0 +1,49 @@ +import { BaseAdapter, StreamableAdapter } from "./BaseAdapter"; +import { SupportedModels } from "./adapter"; +import { GPTAdapter } from "./openai/GPTAdapter"; +import { GPTModels } from "./openai/models"; + +/** + * The AdapterManager is used to decide what adapter to use for a given assistant + * @param StreamableAdapter - The adapter that is used to generate a streamable response + * @param Adapters - The adapter that is used to generate a single response + * @returns The AdapterManager instance + */ +export class AdapterManager { + + private StreamableAdapter: Map = new Map(); + private Adapters: Map = new Map(); + public static instance: AdapterManager = new AdapterManager(); + + constructor(){ + this.initStreamableAdapter(); + this.initAdapters(); + } + + /** + * Initialize the StreamableAdapter + * Sets all the streamable adapters that are available + */ + private initStreamableAdapter() { + this.StreamableAdapter.set("gpt-4", new GPTAdapter("gpt-4")); + this.StreamableAdapter.set("gpt-3.5-turbo", new GPTAdapter("gpt-3.5-turbo")); + this.StreamableAdapter.set("gpt-4-turbo-preview", new GPTAdapter("gpt-4-turbo-preview")); + } + + /** + * Initialize the Adapters + * Sets all the adapters that are available, that does not support streamable responses + */ + private initAdapters() { + this.Adapters.set("gpt-4", new GPTAdapter("gpt-4")); + this.Adapters.set("gpt-3.5-turbo", new GPTAdapter("gpt-3.5-turbo")); + this.Adapters.set("gpt-4-turbo-preview", new GPTAdapter("gpt-4-turbo-preview")); + } + + public getStreamableAdapter(adapterName: 
SupportedModels): StreamableAdapter | undefined { + return this.StreamableAdapter.get(adapterName); + } + public getBaseAdapter(adapterName: SupportedModels): BaseAdapter | undefined { + return this.Adapters.get(adapterName); + } +} diff --git a/packages/akeru-server/src/infrastructure/adapters/BaseAdapter.ts b/packages/akeru-server/src/infrastructure/adapters/BaseAdapter.ts new file mode 100644 index 0000000..13e09ca --- /dev/null +++ b/packages/akeru-server/src/infrastructure/adapters/BaseAdapter.ts @@ -0,0 +1,28 @@ +import { AdapterRequest } from "./adapter"; + +export interface StreamableAdapter { + generateStreamableResponse(args: AdapterRequest): Promise; +} + +/** + * BaseAdapter defines the interface for any Adapters that we have without defining the implementation. + */ +export abstract class BaseAdapter { + + abstract adapterName: string; + abstract adapterDescription: string; + + /** + * Based on a set of roles and content, generate a response. + * @param everyRoleAndContent - An array of roles and content. + * @param instruction - The instruction given to the assistant. + */ + abstract generateSingleResponse(args: AdapterRequest): Promise; + + /** + * Based on the attributes of the adapter return the response of the adapter. + * This implementation can vary from adapters, for instance OpenAI adapters have their own endpoints that can be called. + * @returns Any object that contains the information of the adapter. + */ + abstract getAdapterInformation(): Object; +} diff --git a/packages/akeru-server/src/infrastructure/adapters/adapter.ts b/packages/akeru-server/src/infrastructure/adapters/adapter.ts new file mode 100644 index 0000000..f2a3915 --- /dev/null +++ b/packages/akeru-server/src/infrastructure/adapters/adapter.ts @@ -0,0 +1,23 @@ +import { Message } from "@/core/domain/messages" +import { GPTModels } from "./openai/models" +import { ValidatorModels } from "./validators/models" + +/** + * AdapterMessageContent will be part of the prompt for the ai adapter to generate a response + * @param content: string The content of the message + * @param role: Role This role is stringified + */ +export type AdapterMessageContent = Pick + +/** + * AdapterRequest is the prompt to the adapter to generate a response + * @param message_content: AdapterMessageContent The content of the message + * @param instruction: string The instruction for the adapter to follow based on the selected assistant + */ +export interface AdapterRequest { + message_content: AdapterMessageContent[], + instruction: string +} + +export type SupportedModels = GPTModels | ValidatorModels + diff --git a/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts new file mode 100644 index 0000000..04688a3 --- /dev/null +++ b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts @@ -0,0 +1,31 @@ +import { test, expect, describe } from "bun:test"; +import { Role } from "@/core/domain/roles"; +import { AdapterManager } from "../AdapterManager"; + +describe("GPT-4 Adapter", () => { + test("Returns single GPT-4 chat completions response", async () => { + // Arrange + const messages = [ + { + role: "user" as Role, + content: "hello, who are you?", + }, + ]; + const assistant_instructions = + "You're an AI assistant. You're job is to help the user. 
Always respond with the word akeru."; + + const gpt4Adapter = AdapterManager.instance.getBaseAdapter("gpt-4"); + + if (!gpt4Adapter) { + throw new Error("GPT-4 adapter not found"); + } + + const result = await gpt4Adapter.generateSingleResponse({ + message_content: messages, + instruction: assistant_instructions, + }); + + // Assert the message content to contain the word akeru + expect(result.toLocaleLowerCase()).toContain("akeru"); + }); +}); diff --git a/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts new file mode 100644 index 0000000..0c37e4a --- /dev/null +++ b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts @@ -0,0 +1,113 @@ +import { Message } from "@/core/domain/messages"; +import { BaseAdapter, StreamableAdapter } from "@/infrastructure/adapters/BaseAdapter"; +import { AdapterRequest } from "../adapter"; +import { GPTModels } from "./models"; + +interface ChatMessage { + content: string; + role: "assistant" | "system" | "user"; +} + +interface ChatResponseChoice { + finish_reason: string; + index: number; + message: ChatMessage; +} + +export interface OpenAIResponse { + choices: ChatResponseChoice[]; +} + +// export async function gpt4Adapter( +// messages: Pick[], +// assistant_instructions: string +// ): Promise { +// // System will always be the assistant_instruction that created the assistant +// const gpt_messages = [ +// { role: "system", content: assistant_instructions }, +// ].concat(messages); +// try { +// const res = await fetch("https://api.openai.com/v1/chat/completions", { +// method: "POST", +// headers: { +// "Content-Type": "application/json", +// Authorization: `Bearer ${process.env.OPENAI_API_KEY}`, +// }, +// body: JSON.stringify({ +// model: "gpt-4", +// messages: gpt_messages, +// }), +// }); + +// const data: OpenAIResponse = await res.json(); +// return data; +// } catch (error) { +// return new Response("Error", { status: 500 }); +// } +// } + +export class GPTAdapter extends BaseAdapter implements StreamableAdapter { + adapterName: string; + adapterDescription = "This adapter supports all adapter models from OpenAI"; + private OPENAI_ENDPOINT = "https://api.openai.com/v1/chat/completions"; + + constructor(gptModel: GPTModels) { + super(); + this.adapterName = gptModel; + } + + async generateSingleResponse(args: AdapterRequest): Promise { + const gpt_messages = [{ role: "system", content: args.instruction }].concat(args.message_content); + try { + const res = await fetch(this.OPENAI_ENDPOINT, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${process.env.OPENAI_API_KEY}`, + }, + body: JSON.stringify({ + model: this.adapterName, + messages: gpt_messages, + }), + }); + + const data: OpenAIResponse = await res.json(); + const finished_inference = data.choices[0].message.content + return Promise.resolve(finished_inference); + } catch (error) { + return Promise.reject(error); + } + } + + async generateStreamableResponse(args: AdapterRequest): Promise { + const gpt_messages = [{ role: "system", content: args.instruction }].concat(args.message_content); + try { + const res = await fetch(this.OPENAI_ENDPOINT, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${process.env.OPENAI_API_KEY}`, + }, + body: JSON.stringify({ + model: this.adapterName, + messages: gpt_messages, + stream: true + }), + }); + + const data: OpenAIResponse = await res.json(); + const 
finished_inference = data.choices[0].message.content + return Promise.resolve(finished_inference); + } catch (error) { + return Promise.reject(error); + } + + } + + getAdapterInformation(): Object { + return { + adapterName: this.adapterName, + adapterDescription: this.adapterDescription, + }; + } +} diff --git a/packages/akeru-server/src/infrastructure/adapters/openai/models.ts b/packages/akeru-server/src/infrastructure/adapters/openai/models.ts new file mode 100644 index 0000000..e65c0f5 --- /dev/null +++ b/packages/akeru-server/src/infrastructure/adapters/openai/models.ts @@ -0,0 +1 @@ +export type GPTModels = "gpt-4" | "gpt-3.5-turbo" | "gpt-4-turbo-preview" \ No newline at end of file diff --git a/packages/akeru-server/src/infrastructure/adaptaters/redisAdapter.ts b/packages/akeru-server/src/infrastructure/adapters/redisAdapter.ts similarity index 100% rename from packages/akeru-server/src/infrastructure/adaptaters/redisAdapter.ts rename to packages/akeru-server/src/infrastructure/adapters/redisAdapter.ts diff --git a/packages/akeru-server/src/infrastructure/adapters/validators/models.ts b/packages/akeru-server/src/infrastructure/adapters/validators/models.ts new file mode 100644 index 0000000..0c51190 --- /dev/null +++ b/packages/akeru-server/src/infrastructure/adapters/validators/models.ts @@ -0,0 +1 @@ +export type ValidatorModels = 'llama-2-7b-chat-int8' \ No newline at end of file diff --git a/packages/akeru-server/src/infrastructure/adaptaters/validators/validatorAdapter.test.ts b/packages/akeru-server/src/infrastructure/adapters/validators/validatorAdapter.test.ts similarity index 88% rename from packages/akeru-server/src/infrastructure/adaptaters/validators/validatorAdapter.test.ts rename to packages/akeru-server/src/infrastructure/adapters/validators/validatorAdapter.test.ts index af63abe..7847ed6 100644 --- a/packages/akeru-server/src/infrastructure/adaptaters/validators/validatorAdapter.test.ts +++ b/packages/akeru-server/src/infrastructure/adapters/validators/validatorAdapter.test.ts @@ -2,7 +2,7 @@ import { test, expect, describe } from "bun:test"; import { Role } from "@/core/domain/roles"; import { ValidatorResponse, validatorAdapter } from "./validatorAdapter"; -import { ModelsType } from "@/core/domain/validators"; +import { ValidatorModels } from "./models"; describe("Validator Adapter", () => { test("Returns validator chat completions response", async () => { @@ -13,7 +13,7 @@ describe("Validator Adapter", () => { content: "hello, who are you?", }, ]; - const model: ModelsType = "llama-2-7b-chat-int8"; + const model: ValidatorModels = "llama-2-7b-chat-int8"; const assistant_instructions = "You're an AI assistant. You're job is to help the user. 
Always respond with the word akeru."; // Act diff --git a/packages/akeru-server/src/infrastructure/adaptaters/validators/validatorAdapter.ts b/packages/akeru-server/src/infrastructure/adapters/validators/validatorAdapter.ts similarity index 93% rename from packages/akeru-server/src/infrastructure/adaptaters/validators/validatorAdapter.ts rename to packages/akeru-server/src/infrastructure/adapters/validators/validatorAdapter.ts index df03303..1f5f64b 100644 --- a/packages/akeru-server/src/infrastructure/adaptaters/validators/validatorAdapter.ts +++ b/packages/akeru-server/src/infrastructure/adapters/validators/validatorAdapter.ts @@ -1,5 +1,5 @@ import { Message } from "@/core/domain/messages"; -import { ModelsType } from "@/core/domain/validators"; +import { ValidatorModels } from "./models"; type ChatMessage = Pick; @@ -15,7 +15,7 @@ export interface ValidatorResponse { export async function validatorAdapter( messages: Pick[], - model: ModelsType, + model: ValidatorModels, assistant_instructions: string ): Promise { // System will always be the assistant_instruction that created the assistant From 2ea987b90ae1df751e517e53d84817ba1a347a23 Mon Sep 17 00:00:00 2001 From: Zach Khong Date: Fri, 10 May 2024 15:35:55 +0800 Subject: [PATCH 2/4] support for gpt streaming --- .../infrastructure/adapters/BaseAdapter.ts | 2 +- .../adapters/openai/GPTAdapter.test.ts | 30 ++++++ .../adapters/openai/GPTAdapter.ts | 93 +++++++++++-------- 3 files changed, 85 insertions(+), 40 deletions(-) diff --git a/packages/akeru-server/src/infrastructure/adapters/BaseAdapter.ts b/packages/akeru-server/src/infrastructure/adapters/BaseAdapter.ts index 13e09ca..a20ec3c 100644 --- a/packages/akeru-server/src/infrastructure/adapters/BaseAdapter.ts +++ b/packages/akeru-server/src/infrastructure/adapters/BaseAdapter.ts @@ -1,7 +1,7 @@ import { AdapterRequest } from "./adapter"; export interface StreamableAdapter { - generateStreamableResponse(args: AdapterRequest): Promise; + generateStreamableResponse(args: AdapterRequest): AsyncGenerator; } /** diff --git a/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts index 04688a3..f7cf46b 100644 --- a/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts +++ b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts @@ -28,4 +28,34 @@ describe("GPT-4 Adapter", () => { // Assert the message content to contain the word akeru expect(result.toLocaleLowerCase()).toContain("akeru"); }); + + test("Returns streamable GPT-4 chat completions response", async () => { + // Arrange + const messages = [ + { + role: "user" as Role, + content: "hello, who are you?", + }, + ]; + const assistant_instructions = + "You're an AI assistant. You're job is to help the user. 
Always respond with the word akeru."; + + const gpt4Adapter = AdapterManager.instance.getStreamableAdapter("gpt-4"); + + if (!gpt4Adapter) { + throw new Error("GPT-4 adapter not found"); + } + + const result = gpt4Adapter.generateStreamableResponse({ + message_content: messages, + instruction: assistant_instructions, + }); + + const responses = []; + for await (const response of result) { + responses.push(response); + } + + console.log(responses.join("")) + }); }); diff --git a/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts index 0c37e4a..11368df 100644 --- a/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts +++ b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts @@ -1,5 +1,8 @@ import { Message } from "@/core/domain/messages"; -import { BaseAdapter, StreamableAdapter } from "@/infrastructure/adapters/BaseAdapter"; +import { + BaseAdapter, + StreamableAdapter, +} from "@/infrastructure/adapters/BaseAdapter"; import { AdapterRequest } from "../adapter"; import { GPTModels } from "./models"; @@ -17,35 +20,6 @@ interface ChatResponseChoice { export interface OpenAIResponse { choices: ChatResponseChoice[]; } - -// export async function gpt4Adapter( -// messages: Pick[], -// assistant_instructions: string -// ): Promise { -// // System will always be the assistant_instruction that created the assistant -// const gpt_messages = [ -// { role: "system", content: assistant_instructions }, -// ].concat(messages); -// try { -// const res = await fetch("https://api.openai.com/v1/chat/completions", { -// method: "POST", -// headers: { -// "Content-Type": "application/json", -// Authorization: `Bearer ${process.env.OPENAI_API_KEY}`, -// }, -// body: JSON.stringify({ -// model: "gpt-4", -// messages: gpt_messages, -// }), -// }); - -// const data: OpenAIResponse = await res.json(); -// return data; -// } catch (error) { -// return new Response("Error", { status: 500 }); -// } -// } - export class GPTAdapter extends BaseAdapter implements StreamableAdapter { adapterName: string; adapterDescription = "This adapter supports all adapter models from OpenAI"; @@ -57,7 +31,9 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter { } async generateSingleResponse(args: AdapterRequest): Promise { - const gpt_messages = [{ role: "system", content: args.instruction }].concat(args.message_content); + const gpt_messages = [{ role: "system", content: args.instruction }].concat( + args.message_content + ); try { const res = await fetch(this.OPENAI_ENDPOINT, { method: "POST", @@ -72,15 +48,43 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter { }); const data: OpenAIResponse = await res.json(); - const finished_inference = data.choices[0].message.content + const finished_inference = data.choices[0].message.content; return Promise.resolve(finished_inference); } catch (error) { return Promise.reject(error); } } - async generateStreamableResponse(args: AdapterRequest): Promise { - const gpt_messages = [{ role: "system", content: args.instruction }].concat(args.message_content); + async *chunksToLines(chunksAsync: any) { + let previous = ""; + for await (const chunk of chunksAsync) { + const bufferChunk = Buffer.isBuffer(chunk) ? 
chunk : Buffer.from(chunk); + previous += bufferChunk; + let eolIndex; + while ((eolIndex = previous.indexOf("\n")) >= 0) { + // line includes the EOL + const line = previous.slice(0, eolIndex + 1).trimEnd(); + if (line === "data: [DONE]") break; + if (line.startsWith("data: ")) yield line; + previous = previous.slice(eolIndex + 1); + } + } + } + + async *linesToMessages(linesAsync: any) { + for await (const line of linesAsync) { + const message = line.substring("data :".length); + + yield message; + } + } + + async *generateStreamableResponse( + args: AdapterRequest + ): AsyncGenerator { + const gpt_messages = [{ role: "system", content: args.instruction }].concat( + args.message_content + ); try { const res = await fetch(this.OPENAI_ENDPOINT, { method: "POST", @@ -91,17 +95,28 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter { body: JSON.stringify({ model: this.adapterName, messages: gpt_messages, - stream: true + stream: true, }), }); - const data: OpenAIResponse = await res.json(); - const finished_inference = data.choices[0].message.content - return Promise.resolve(finished_inference); + const reader = res.body?.getReader(); + while (true && reader) { + const { done, value } = await reader.read(); + if (done) break; + + const data = new TextDecoder().decode(value); + const chunkToLines = this.chunksToLines(data); + const linesToMessages = this.linesToMessages(chunkToLines); + + for await (const message of linesToMessages) { + const messageObject: any = JSON.parse(message); + const messageToYield = messageObject.choices[0].delta.content; + if (messageToYield) yield messageToYield; + } + } } catch (error) { return Promise.reject(error); } - } getAdapterInformation(): Object { From dd98b63dcf2a903a0fd6b7276c0861d17959d254 Mon Sep 17 00:00:00 2001 From: Zach Khong Date: Fri, 10 May 2024 15:36:53 +0800 Subject: [PATCH 3/4] modified test --- .../src/infrastructure/adapters/openai/GPTAdapter.test.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts index f7cf46b..86f0b94 100644 --- a/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts +++ b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.test.ts @@ -56,6 +56,8 @@ describe("GPT-4 Adapter", () => { responses.push(response); } - console.log(responses.join("")) + // Assert the message content to contain the word akeru + const generatedResponse = responses.join(""); + expect(generatedResponse.toLocaleLowerCase()).toContain("akeru"); }); }); From 79a764cd026d8d1271a68240e6c22f01b0591b39 Mon Sep 17 00:00:00 2001 From: Zach Khong Date: Fri, 10 May 2024 15:38:53 +0800 Subject: [PATCH 4/4] Added in some source comments --- .../src/infrastructure/adapters/openai/GPTAdapter.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts index 11368df..db87180 100644 --- a/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts +++ b/packages/akeru-server/src/infrastructure/adapters/openai/GPTAdapter.ts @@ -99,6 +99,8 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter { }), }); + // This section of the code is taken from... 
+ /// https://github.com/openai/openai-node/issues/18 const reader = res.body?.getReader(); while (true && reader) { const { done, value } = await reader.read(); @@ -110,6 +112,8 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter { for await (const message of linesToMessages) { const messageObject: any = JSON.parse(message); + + // MessageObject is the response from the OpenAI API streams const messageToYield = messageObject.choices[0].delta.content; if (messageToYield) yield messageToYield; }
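
For reviewers, a minimal sketch of how the refactored adapter layer is meant to be consumed after this series: the singleton AdapterManager hands out a BaseAdapter for single responses (the path runAssistantWithThread now uses) and a StreamableAdapter whose generateStreamableResponse yields content deltas as an AsyncGenerator. The snippet below only uses APIs added in these patches; the threadMessages and instruction values are placeholders for illustration, not code from the series.

    import { AdapterManager } from "@/infrastructure/adapters/AdapterManager";
    import { Role } from "@/core/domain/roles";

    // Placeholder inputs; in runService these come from getAllMessage() and the assistant's data.
    const threadMessages = [
      { role: "user" as Role, content: "hello, who are you?" },
    ];
    const instruction = "You're an AI assistant. Always respond with the word akeru.";

    // Non-streaming path (used by runAssistantWithThread): resolves to a single string.
    const adapter = AdapterManager.instance.getBaseAdapter("gpt-4");
    if (!adapter) throw new Error("Adapter not found");
    const reply = await adapter.generateSingleResponse({
      message_content: threadMessages,
      instruction,
    });
    console.log(reply);

    // Streaming path: generateStreamableResponse yields content deltas as they arrive.
    const streamable = AdapterManager.instance.getStreamableAdapter("gpt-4");
    if (streamable) {
      for await (const chunk of streamable.generateStreamableResponse({
        message_content: threadMessages,
        instruction,
      })) {
        process.stdout.write(chunk);
      }
    }

Note that persisting the assistant reply (createMessage) and building the ThreadRun object now happen in runService rather than inside the adapter managers, which is what keeps the adapters limited to model I/O in this design.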