Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

feat: add validator adapter #97

Merged
merged 14 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/akeru-frontend/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export default function Home() {
<main className="min-h-screen">
<HeroSection />


<FormSection />

<section className="flex flex-col mt-10 md:mt-36 w-full gap-4 md:flex-row md:gap-10 mx-auto">
Expand Down
1 change: 0 additions & 1 deletion packages/akeru-server/scripts/createSuperAdmin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {

async function createSuperAdminWithApiKey() {
const userData = {
// Add any necessary user data here
};

const userId = await createSuperAdmin(userData);
Expand Down
29 changes: 27 additions & 2 deletions packages/akeru-server/src/core/application/services/runService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { createMessage, getAllMessage } from "./messageService";
import { gpt4Adapter } from "@/infrastructure/adaptaters/openai/gpt4Adapter";
import { Role } from "@/core/domain/roles";
import { v4 as uuidv4 } from "uuid";
import { validatorAdapter } from "@/infrastructure/adaptaters/validators/validatorAdapter";

export async function runAssistantWithThread(runData: ThreadRunRequest) {
// get all messages from the thread, and run it over to the assistant to get a response
Expand All @@ -17,7 +18,8 @@ export async function runAssistantWithThread(runData: ThreadRunRequest) {
]);

// If no thread data or assistant data, an error should be thrown as we need both to run a thread
if (!threadData || !assistantData) throw new Error("No thread or assistant found.");
if (!threadData || !assistantData)
throw new Error("No thread or assistant found.");

const everyMessage = await getAllMessage(threadData.id);
// only get role and content from every message for context.
Expand All @@ -38,7 +40,30 @@ export async function runAssistantWithThread(runData: ThreadRunRequest) {
);

const assistantResponse: string = gpt4AdapterRes.choices[0].message.content;


// add assistant response to the thread
await createMessage(assistant_id, thread_id, assistantResponse);

const threadRunResponse: ThreadRun = {
id: uuidv4(),
assistant_id: assistant_id,
thread_id: thread_id,
created_at: new Date(),
};

return threadRunResponse;
}

if (assistantData.model === "llama-2-7b-chat-int8") {
const validatorAdapterRes: any = await validatorAdapter(
everyRoleAndContent,
"llama-2-7b-chat-int8",
assistantData.instruction
);

const assistantResponse: string =
validatorAdapterRes.choices[0].message.content;

// add assistant response to the thread
await createMessage(assistant_id, thread_id, assistantResponse);

Expand Down
6 changes: 4 additions & 2 deletions packages/akeru-server/src/core/domain/assistant.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import { Validator } from "./validators";

export type Assistant = {
id: string;
name: string;
model: "gpt-4";
model: Validator;
tools: { type: string }[];
fileIds: string[];
instruction: string
instruction: string;
};
1 change: 1 addition & 0 deletions packages/akeru-server/src/core/domain/validators.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/**
 * Identifiers of the validators / LLM models the platform supports.
 * Used to tag an Assistant with the backend that should serve its runs.
 */
export type Validator =
  | "llama-2-7b-chat-int8"
  | "whisper"
  | "gpt-4";
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ describe("GPT-4 Adapter", () => {
);
});
});


Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { test, expect, describe } from "bun:test";

import { Role } from "@/core/domain/roles";
import { ValidatorResponse, validatorAdapter } from "./validatorAdapter";

describe("Validator Adapter", () => {
  test("Returns validator chat completions response", async () => {
    // Arrange: the model under test and the system prompt the assistant
    // was created with, plus a single user message.
    const model = "llama-2-7b-chat-int8";
    const assistant_instructions =
      "You're an AI assistant. You're job is to help the user. Always respond with the word akeru.";
    const messages = [
      {
        role: "user" as Role,
        content: "hello, who are you?",
      },
    ];

    // Act: run the adapter against the validator service.
    const result = (await validatorAdapter(
      messages,
      model,
      assistant_instructions
    )) as ValidatorResponse;

    // Assert: the completion should honour the instruction and mention akeru.
    const completion = result.choices[0].message.content;
    expect(completion.toLocaleLowerCase()).toContain("akeru");
  });
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { Message } from "@/core/domain/messages";
import { Validator } from "@/core/domain/validators";

/** A single chat message exchanged with the validator service. */
interface ChatMessage {
  content: string;
  role: "assistant" | "system" | "user";
}

/** One completion candidate returned by the validator service. */
interface ChatResponseChoice {
  finish_reason: string;
  index: number;
  message: ChatMessage;
}

/** Top-level chat-completion payload returned by the validator endpoint. */
export interface ValidatorResponse {
  choices: ChatResponseChoice[];
}


/**
 * Sends a chat-completion request to the Akeru validator service.
 *
 * The assistant's instruction is always injected as the leading `system`
 * message so the model responds with the persona the assistant was
 * created with.
 *
 * @param messages - Prior thread messages (role + content only).
 * @param model - Validator / LLM model identifier to run the chat against.
 * @param assistant_instructions - System prompt of the owning assistant.
 * @returns The parsed {@link ValidatorResponse} on success, or a 500
 *          `Response` when the request fails or returns a non-2xx status.
 */
export async function validatorAdapter(
  messages: Pick<Message, "role" | "content">[],
  model: Validator,
  assistant_instructions: string
): Promise<ValidatorResponse | Response> {
  // System will always be the assistant_instruction that created the assistant
  const validator_messages = [
    { role: "system", content: assistant_instructions },
    ...messages,
  ];
  try {
    const res = await fetch("https://akeru-validator.onrender.com/chat", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        model: model,
        messages: validator_messages,
      }),
    });

    // Surface non-2xx statuses instead of parsing an error payload as a
    // successful completion.
    if (!res.ok) {
      throw new Error(`Validator request failed with status ${res.status}`);
    }

    const data: ValidatorResponse = await res.json();
    return data;
  } catch (error) {
    // Log before mapping to a generic 500 so the failure is diagnosable.
    console.error("validatorAdapter error:", error);
    return new Response("Error", { status: 500 });
  }
}
Loading