Skip to content

Commit

Permalink
style updated
Browse files Browse the repository at this point in the history
  • Loading branch information
arun477 committed Dec 24, 2024
1 parent 0aa6718 commit 1fc890a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
18 changes: 13 additions & 5 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import logging
import contextlib
import db as _db
import numpy as np
from tqdm import tqdm

from rich.panel import Panel
from rich.table import Table
Expand Down Expand Up @@ -43,8 +45,9 @@ def suppress_outputs():


# model infos
MODEL_NAME = "answerdotai/ModernBERT-base"
MODEL_NAME = "all-mpnet-base-v2"
MODEL_NAME = "answerdotai/ModernBERT-large"
# MODEL_NAME = "answerdotai/ModernBERT-base"
# MODEL_NAME = "all-mpnet-base-v2"
SUIK_LOGO = "🦦"


Expand All @@ -62,7 +65,7 @@ def load_model(model_name):


def load_meta(meta_file):
return _db.fetch_all_documents()
# return _db.fetch_all_documents()
with open(meta_file, "r") as f:
return json.loads(f.read())

Expand All @@ -72,7 +75,12 @@ def embed(items, model):


def load_meta_emb(meta, model):
return embed([doc["description"] for doc in meta], model)
console.print(Panel.fit(f"{SUIK_LOGO} Setup...", title="Embedding"))
# for some reason modernbert throws nan value if we pass the batch of items
embs = []
for doc in tqdm(meta):
embs.append(embed(doc["description"], model))
return np.vstack(embs)


def get_meta_match(meta_emb, q_emb, model):
Expand Down Expand Up @@ -129,7 +137,7 @@ def main():
meta_emb = load_meta_emb(meta, model)
console.print(Panel.fit("✨ Ask me any linux commands (type 'exit' to quit)", title=f"{SUIK_LOGO} Suika Ready"))
while True:
question = Prompt.ask(f"[#B7D46F]Your question {SUIK_LOGO} ")
question = Prompt.ask(f"[#B7D46F]Your question {SUIK_LOGO} ")
if question.lower() == "exit":
console.print(f"{SUIK_LOGO} 👋 Goodbye!")
break
Expand Down
1 change: 0 additions & 1 deletion dev.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,6 @@
"cell_type": "markdown",
"id": "07e4dd5c-cd70-4571-86a7-1b8a4deb8418",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"scrolled": true
},
"source": [
Expand Down

0 comments on commit 1fc890a

Please sign in to comment.