Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

feat: add validator adapter #97

Merged
merged 14 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/akeru-frontend/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export default function Home() {
<main className="min-h-screen">
<HeroSection />


<FormSection />

<section className="flex flex-col mt-10 md:mt-36 w-full gap-4 md:flex-row md:gap-10 mx-auto">
Expand Down
2 changes: 1 addition & 1 deletion packages/akeru-server/scripts/createSuperAdmin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {

async function createSuperAdminWithApiKey() {
const userData = {
// Add any necessary user data here
};

const userId = await createSuperAdmin(userData);
Expand All @@ -15,6 +14,7 @@ async function createSuperAdminWithApiKey() {
}

const apiKey = await createApiToken(userId);

if (!apiKey) {
console.error("Failed to create API key for super admin user");
return;
Expand Down
58 changes: 36 additions & 22 deletions packages/akeru-server/src/core/application/services/runService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
import { ThreadRun, ThreadRunRequest } from "@/core/domain/run";
import { getAssistantData } from "./assistantService";
import { getThread } from "./threadService";
import { createMessage, getAllMessage } from "./messageService";
import { gpt4Adapter } from "@/infrastructure/adaptaters/openai/gpt4Adapter";
import { getAllMessage } from "./messageService";

import { Role } from "@/core/domain/roles";
import { v4 as uuidv4 } from "uuid";

import {
AdapterManager,
GPT4AdapterManager,
ValidatorAdapterManager,
} from "@/infrastructure/adaptaters/adaptermanager";

export async function runAssistantWithThread(runData: ThreadRunRequest) {
// get all messages from the thread, and run it over to the assistant to get a response
Expand All @@ -17,7 +22,8 @@ export async function runAssistantWithThread(runData: ThreadRunRequest) {
]);

// If no thread data or assistant data, an error should be thrown as we need both to run a thread
if (!threadData || !assistantData) throw new Error("No thread or assistant found.");
if (!threadData || !assistantData)
throw new Error("No thread or assistant found.");

const everyMessage = await getAllMessage(threadData.id);
// only get role and content from every message for context.
Expand All @@ -30,25 +36,33 @@ export async function runAssistantWithThread(runData: ThreadRunRequest) {
};
});

let adapterManager: AdapterManager;

// Calls the appropriate adapter based on what model the assistant uses
if (assistantData.model === "gpt-4") {
const gpt4AdapterRes: any = await gpt4Adapter(
everyRoleAndContent,
assistantData.instruction
);

const assistantResponse: string = gpt4AdapterRes.choices[0].message.content;

// add assistant response to the thread
await createMessage(assistant_id, thread_id, assistantResponse);

const threadRunResponse: ThreadRun = {
id: uuidv4(),
assistant_id: assistant_id,
thread_id: thread_id,
created_at: new Date(),
};

return threadRunResponse;
adapterManager = new GPT4AdapterManager();
} else if (assistantData.model === "llama-2-7b-chat-int8") {
adapterManager = new ValidatorAdapterManager("llama-2-7b-chat-int8");
} else {
throw new Error("Unsupported assistant model");
}

// Generate response using the appropriate adapter
const assistantResponse: string = await adapterManager.generateResponse(
everyRoleAndContent,
assistantData.instruction
);

// Add assistant response to the thread
await adapterManager.createUserMessage(
assistant_id,
thread_id,
assistantResponse
);

// Create thread run response
const threadRunResponse: ThreadRun =
await adapterManager.createThreadRunResponse(assistant_id, thread_id);

return threadRunResponse;
}
6 changes: 4 additions & 2 deletions packages/akeru-server/src/core/domain/assistant.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import { ModelsType } from "./validators";

/**
 * An assistant: a named, model-backed agent that can be run against a thread.
 */
export type Assistant = {
  id: string;
  name: string;
  // Which backing model serves this assistant (see ModelsType in ./validators).
  model: ModelsType;
  tools: { type: string }[];
  fileIds: string[];
  // System instruction the assistant was created with.
  instruction: string;
};
2 changes: 1 addition & 1 deletion packages/akeru-server/src/core/domain/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ export interface ThreadRun {
thread_id: Thread["id"];
}

// The subset of ThreadRun a caller must supply to start a run.
export type ThreadRunRequest = Pick<ThreadRun, "thread_id" | "assistant_id">;
7 changes: 7 additions & 0 deletions packages/akeru-server/src/core/domain/validators.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Model identifiers for all the models we'll have on the subnet.
// NOTE: these are two separate aliases joined by a union below; intersecting
// them instead ("BtModels & Models") would collapse to `never`, because no
// single string literal belongs to both disjoint literal sets.

type BtModels = "bt-gpt-4";
type Models = "gpt-4" | "gpt-3.5" | "llama-2-7b-chat-int8";

// Union of every model identifier an assistant may be configured with.
export type ModelsType = BtModels | Models;
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import { ThreadRun } from "@/core/domain/run";
import { v4 as uuidv4 } from "uuid";
import { gpt4Adapter } from "./openai/gpt4Adapter";
import { validatorAdapter } from "./validators/validatorAdapter";

import { createMessage } from "@/core/application/services/messageService";
import { ModelsType } from "@/core/domain/validators";

/**
 * Base class for model adapters. Subclasses implement `generateResponse`
 * for a specific backing model; the shared helpers persist the assistant's
 * reply on the thread and build the ThreadRun record returned to callers.
 */
export abstract class AdapterManager {
  /**
   * Produce the assistant's reply for the given conversation.
   *
   * @param everyRoleAndContent ordered chat history as role/content pairs
   * @param instruction system instruction the assistant was created with
   *   (typed `string` to match `Assistant.instruction`)
   */
  abstract generateResponse(
    everyRoleAndContent: { role: string; content: string }[],
    instruction: string
  ): Promise<string>;

  /** Persist the assistant's response as a message on the thread. */
  async createUserMessage(
    assistant_id: string,
    thread_id: string,
    assistantResponse: string
  ): Promise<void> {
    await createMessage(assistant_id, thread_id, assistantResponse);
  }

  /** Build the ThreadRun record: fresh UUID, current timestamp. */
  async createThreadRunResponse(
    assistant_id: string,
    thread_id: string
  ): Promise<ThreadRun> {
    const threadRunResponse: ThreadRun = {
      id: uuidv4(),
      assistant_id: assistant_id,
      thread_id: thread_id,
      created_at: new Date(),
    };
    return threadRunResponse;
  }
}

/** Adapter manager that routes completions through the OpenAI GPT-4 adapter. */
export class GPT4AdapterManager extends AdapterManager {
  /**
   * Generate a completion via `gpt4Adapter` and return the first choice's
   * message content.
   *
   * @param everyRoleAndContent ordered chat history as role/content pairs
   * @param instruction system instruction for the assistant
   */
  async generateResponse(
    everyRoleAndContent: { role: string; content: string }[],
    instruction: string
  ): Promise<string> {
    // Narrow the adapter's untyped response to the one field we read,
    // instead of propagating `any` through the class.
    const gpt4AdapterRes = (await gpt4Adapter(
      everyRoleAndContent,
      instruction
    )) as { choices: { message: { content: string } }[] };
    return gpt4AdapterRes.choices[0].message.content;
  }
}

/** Adapter manager that routes completions through the subnet validator. */
export class ValidatorAdapterManager extends AdapterManager {
  /**
   * @param modelName model identifier forwarded to the validator on every
   *   request (declared as a parameter property to avoid field boilerplate)
   */
  constructor(private readonly modelName: ModelsType) {
    super();
  }

  /**
   * Generate a completion via `validatorAdapter` and return the first
   * choice's message content.
   *
   * @param everyRoleAndContent ordered chat history as role/content pairs
   * @param instruction system instruction for the assistant
   */
  async generateResponse(
    everyRoleAndContent: { role: string; content: string }[],
    instruction: string
  ): Promise<string> {
    // Narrow the adapter's response to the one field we read.
    const validatorAdapterRes = (await validatorAdapter(
      everyRoleAndContent,
      this.modelName,
      instruction
    )) as { choices: { message: { content: string } }[] };
    return validatorAdapterRes.choices[0].message.content;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ describe("GPT-4 Adapter", () => {
);
});
});


Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import { test, expect, describe } from "bun:test";

import { Role } from "@/core/domain/roles";
import { ValidatorResponse, validatorAdapter } from "./validatorAdapter";
import { ModelsType } from "@/core/domain/validators";

describe("Validator Adapter", () => {
  // NOTE(review): this test calls the live validator endpoint over the
  // network, so it is slow and can fail for reasons unrelated to this code
  // (endpoint down, cold start, model change) — consider mocking fetch.
  test("Returns validator chat completions response", async () => {
    // Arrange: one user turn plus an instruction that pins the expected reply.
    const messages = [
      {
        role: "user" as Role,
        content: "hello, who are you?",
      },
    ];
    const model: ModelsType = "llama-2-7b-chat-int8";
    const assistant_instructions =
      "You're an AI assistant. You're job is to help the user. Always respond with the word akeru.";
    // Act
    const result = (await validatorAdapter(
      messages,
      model,
      assistant_instructions
    )) as ValidatorResponse;

    // Assert the message content to contain the word akeru
    expect(result.choices[0].message.content.toLocaleLowerCase()).toContain(
      "akeru"
    );
  });
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { Message } from "@/core/domain/messages";
import { ModelsType } from "@/core/domain/validators";

// A chat turn as sent to / returned by the validator: just role + content.
type ChatMessage = Pick<Message, "role" | "content">;

// One completion choice in the validator's (OpenAI-style) response.
interface ChatResponseChoice {
  finish_reason: string;
  index: number;
  message: ChatMessage;
}

// Shape of the validator's chat-completions response body.
export interface ValidatorResponse {
  choices: ChatResponseChoice[];
}

/**
 * Call the Akeru validator chat-completions endpoint.
 *
 * @param messages chat history (role/content pairs) to send
 * @param model model identifier the validator should run
 * @param assistant_instructions system prompt; always sent as the first message
 * @returns the validator's chat-completions response
 * @throws Error when the HTTP request fails or returns a non-2xx status
 *   (previously a failure returned a `Response` object, which callers then
 *   dereferenced as `.choices[0]...` and crashed with a confusing TypeError)
 */
export async function validatorAdapter(
  messages: Pick<Message, "role" | "content">[],
  model: ModelsType,
  assistant_instructions: string
): Promise<ValidatorResponse> {
  // System will always be the assistant_instruction that created the assistant
  const validator_messages = [
    { role: "system", content: assistant_instructions },
  ].concat(messages);

  const res = await fetch("https://akeru-validator.onrender.com/chat", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      // Authorization: `Bearer ${process.env.OPENAI_API_KEY}`,
    },
    body: JSON.stringify({
      model: model,
      messages: validator_messages,
    }),
  });

  // Surface HTTP-level failures instead of parsing an error body as a result.
  if (!res.ok) {
    throw new Error(`Validator request failed with status ${res.status}`);
  }

  const data: ValidatorResponse = await res.json();
  return data;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ export const redisConfig = {
? { password: process.env.REDIS_PASSWORD }
: {}), // Redis password
db: process.env.REDIS_DB ? parseInt(process.env.REDIS_DB) : 0, // Redis DB
tls: process.env.NODE_ENV === "production" ? {} : undefined,
tls: {},
};
Loading