Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
clean up for adapters for both streamable adapters and nonstreamable …
Browse files Browse the repository at this point in the history
…ones
  • Loading branch information
multipletwigs committed May 9, 2024
1 parent 34e20a9 commit 053efd7
Show file tree
Hide file tree
Showing 26 changed files with 282 additions and 189 deletions.
1 change: 1 addition & 0 deletions packages/akeru-server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
},
"dependencies": {
"@elysiajs/bearer": "^1.0.2",
"@elysiajs/stream": "^1.0.2",
"@elysiajs/swagger": "^1.0.4",
"dayjs": "^1.11.10",
"elysia": "latest",
Expand Down
2 changes: 1 addition & 1 deletion packages/akeru-server/scripts/cleanDB.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { redis } from "@/infrastructure/adaptaters/redisAdapter";
import { redis } from "@/infrastructure/adapters/redisAdapter";

const script = async () => {
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { createSuperAdminForTesting } from "@/__tests__/utils";
import { app } from "@/index";
import { redis } from "@/infrastructure/adaptaters/redisAdapter";
import { redis } from "@/infrastructure/adapters/redisAdapter";
import { test, expect, describe } from "bun:test";
import { UNAUTHORIZED_MISSING_TOKEN } from "../../ports/returnValues";

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { assignRole, createUser } from "core/application/services/userService";
import { createToken } from "@/core/application/services/tokenService";
import { redis } from "@/infrastructure/adaptaters/redisAdapter";
import { redis } from "@/infrastructure/adapters/redisAdapter";
import { ulid } from "ulid";

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { Assistant } from "@/core/domain/assistant";
import { redis } from "@/infrastructure/adaptaters/redisAdapter";
import { redis } from "@/infrastructure/adapters/redisAdapter";

/**
* Creates an assistant in Redis if it does not exist and adds a 'CREATED_BY' relationship to the user.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { v4 as uuidv4 } from "uuid";
import { redis } from "@/infrastructure/adaptaters/redisAdapter";
import { Message } from "@/core/domain/messages";
import dayjs from "dayjs";
import { getUserRole } from "./userService";
import { redis } from "@/infrastructure/adapters/redisAdapter";

export async function createMessage(
userId: string,
Expand Down
46 changes: 19 additions & 27 deletions packages/akeru-server/src/core/application/services/runService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
import { ThreadRun, ThreadRunRequest } from "@/core/domain/run";
import { getAssistantData } from "./assistantService";
import { getThread } from "./threadService";
import { getAllMessage } from "./messageService";
import { createMessage, getAllMessage } from "./messageService";

import { Role } from "@/core/domain/roles";
import { AdapterManager } from "@/infrastructure/adapters/AdapterManager";

import {
AdapterManager,
GPT4AdapterManager,
ValidatorAdapterManager,
} from "@/infrastructure/adaptaters/adaptermanager";
import { v4 as uuidv4 } from "uuid";

export async function runAssistantWithThread(runData: ThreadRunRequest) {
// get all messages from the thread, and run it over to the assistant to get a response
Expand All @@ -36,33 +33,28 @@ export async function runAssistantWithThread(runData: ThreadRunRequest) {
};
});

let adapterManager: AdapterManager;
const adapter = AdapterManager.instance.getBaseAdapter(assistantData.model);

// Calls the appropriate adapter based on what model the assistant uses
if (assistantData.model === "gpt-4") {
adapterManager = new GPT4AdapterManager();
} else if (assistantData.model === "llama-2-7b-chat-int8") {
adapterManager = new ValidatorAdapterManager("llama-2-7b-chat-int8");
} else {
throw new Error("Unsupported assistant model");
if (!adapter) {
throw new Error("Adapter not found");
}

// Generate response using the appropriate adapter
const assistantResponse: string = await adapterManager.generateResponse(
everyRoleAndContent,
assistantData.instruction
);
const assistantResponse: string = await adapter.generateSingleResponse({
message_content: everyRoleAndContent,
instruction: assistantData.instruction,
});

// Add assistant response to the thread
await adapterManager.createUserMessage(
assistant_id,
thread_id,
assistantResponse
);

// Create thread run response
const threadRunResponse: ThreadRun =
await adapterManager.createThreadRunResponse(assistant_id, thread_id);
await createMessage(assistant_id, thread_id, assistantResponse);

// Return the thread run response
const threadRunResponse: ThreadRun = {
assistant_id: assistant_id,
thread_id: thread_id,
id: uuidv4(),
created_at: new Date(),
};

return threadRunResponse;
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Thread } from "@/core/domain/thread";
import { redis } from "@/infrastructure/adaptaters/redisAdapter";
import { redis } from "@/infrastructure/adapters/redisAdapter";

/**
* Creates a new thread in Redis.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { redis } from "@/infrastructure/adaptaters/redisAdapter";
import { redis } from "@/infrastructure/adapters/redisAdapter";
import { getUserPermissions } from "@/core/application/services/userService";
import jwt from "jsonwebtoken";
import { PermissionDetailArray } from "@/core/domain/permissions";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import { Role, getRolePermissions } from "@/core/domain/roles";
import { HumanUserBody, User, UserRole } from "@/core/domain/user";
import { redis } from "@/infrastructure/adaptaters/redisAdapter";
import { redis } from "@/infrastructure/adapters/redisAdapter";
import { ulid } from "ulid";

/**
Expand Down
4 changes: 2 additions & 2 deletions packages/akeru-server/src/core/domain/assistant.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { ModelsType } from "./validators";
import { SupportedModels } from "@/infrastructure/adapters/adapter";

export type Assistant = {
id: string;
name: string;
model: ModelsType;
model: SupportedModels;
tools: { type: string }[];
fileIds: string[];
instruction: string;
Expand Down
3 changes: 2 additions & 1 deletion packages/akeru-server/src/core/domain/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export interface ThreadRun {
created_at: Date;
assistant_id: Assistant["id"];
thread_id: Thread["id"];
stream?: boolean
}

export type ThreadRunRequest = Pick<ThreadRun, "thread_id" | "assistant_id">;
export type ThreadRunRequest = Pick<ThreadRun, "thread_id" | "assistant_id" | 'stream'>;
7 changes: 0 additions & 7 deletions packages/akeru-server/src/core/domain/validators.ts

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { BaseAdapter, StreamableAdapter } from "./BaseAdapter";
import { SupportedModels } from "./adapter";
import { GPTAdapter } from "./openai/GPTAdapter";
import { GPTModels } from "./openai/models";

/**
* The AdapterManager is used to decide what adapter to use for a given assistant
* @param StreamableAdapter - The adapter that is used to generate a streamable response
* @param Adapters - The adapter that is used to generate a single response
* @returns The AdapterManager instance
*/
export class AdapterManager {

private StreamableAdapter: Map<SupportedModels, StreamableAdapter> = new Map();
private Adapters: Map<SupportedModels, BaseAdapter> = new Map();
public static instance: AdapterManager = new AdapterManager();

constructor(){
this.initStreamableAdapter();
this.initAdapters();
}

/**
* Initialize the StreamableAdapter
* Sets all the streamable adapters that are available
*/
private initStreamableAdapter() {
this.StreamableAdapter.set("gpt-4", new GPTAdapter("gpt-4"));
this.StreamableAdapter.set("gpt-3.5-turbo", new GPTAdapter("gpt-3.5-turbo"));
this.StreamableAdapter.set("gpt-4-turbo-preview", new GPTAdapter("gpt-4-turbo-preview"));
}

/**
* Initialize the Adapters
* Sets all the adapters that are available, that does not support streamable responses
*/
private initAdapters() {
this.Adapters.set("gpt-4", new GPTAdapter("gpt-4"));
this.Adapters.set("gpt-3.5-turbo", new GPTAdapter("gpt-3.5-turbo"));
this.Adapters.set("gpt-4-turbo-preview", new GPTAdapter("gpt-4-turbo-preview"));
}

public getStreamableAdapter(adapterName: SupportedModels): StreamableAdapter | undefined {
return this.StreamableAdapter.get(adapterName);
}
public getBaseAdapter(adapterName: SupportedModels): BaseAdapter | undefined {
return this.Adapters.get(adapterName);
}
}
28 changes: 28 additions & 0 deletions packages/akeru-server/src/infrastructure/adapters/BaseAdapter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { AdapterRequest } from "./adapter";

export interface StreamableAdapter {
generateStreamableResponse(args: AdapterRequest): Promise<string>;
}

/**
* BaseAdapter defines the interface for any Adapters that we have without defining the implementation.
*/
export abstract class BaseAdapter {

abstract adapterName: string;
abstract adapterDescription: string;

/**
* Based on a set of roles and content, generate a response.
* @param everyRoleAndContent - An array of roles and content.
* @param instruction - The instruction given to the assistant.
*/
abstract generateSingleResponse(args: AdapterRequest): Promise<string>;

/**
* Based on the attributes of the adapter return the response of the adapter.
* This implementation can vary from adapters, for instance OpenAI adapters have their own endpoints that can be called.
* @returns Any object that contains the information of the adapter.
*/
abstract getAdapterInformation(): Object;
}
Loading

0 comments on commit 053efd7

Please sign in to comment.