Skip to content

Commit

Permalink
linux command added to db
Browse files Browse the repository at this point in the history
  • Loading branch information
arun477 committed Dec 24, 2024
1 parent d268ac7 commit cd9fe2d
Show file tree
Hide file tree
Showing 7 changed files with 24,737 additions and 44 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.ipynb_checkpoints
__pycache__
tldr
132 changes: 89 additions & 43 deletions agent.py
Original file line number Diff line number Diff line change
@@ -1,113 +1,159 @@
import typer
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer, util
import os
import torch
import torch._dynamo
import json
from rich import print
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.markdown import Markdown
import warnings
import logging
import contextlib
from PIL import Image
import db as _db

from rich.panel import Panel
from rich.table import Table
from rich.markdown import Markdown
from rich.console import Group
from rich.syntax import Syntax
from rich.style import Style
from rich.text import Text
from rich.prompt import Prompt


# Ignore every warning issued through the ``warnings`` module.
warnings.filterwarnings('ignore')
# suppress warnings
warnings.filterwarnings("ignore")
# Raise the root logger's threshold to CRITICAL.
logging.getLogger().setLevel(logging.CRITICAL)
# Ask torch._dynamo to suppress its own errors.
torch._dynamo.config.suppress_errors = True
# Apply the same CRITICAL threshold to the "torch", "torch._dynamo" and "triton" loggers.
logging.getLogger("torch").setLevel(logging.CRITICAL)
logging.getLogger("torch._dynamo").setLevel(logging.CRITICAL)
logging.getLogger("triton").setLevel(logging.CRITICAL)

# NOTE(review): MODEL_NAME, META_FILE and console are assigned again further
# down this file; by the time main() runs, the later assignments are in effect.
MODEL_NAME = "answerdotai/ModernBERT-base"
META_FILE = 'meta.json'
console = Console()

@contextlib.contextmanager
def suppress_outputs():
    """Silence stdout, stderr and warnings for the duration of the ``with`` body.

    Both standard streams are redirected to ``os.devnull`` and all warnings
    are ignored. The previous streams and warning filters are restored when
    the block exits, whether normally or through an exception.
    """
    with contextlib.ExitStack() as stack:
        # Contexts are entered in this order and unwound in reverse on exit.
        sink = stack.enter_context(open(os.devnull, "w"))
        stack.enter_context(contextlib.redirect_stdout(sink))
        stack.enter_context(contextlib.redirect_stderr(sink))
        stack.enter_context(warnings.catch_warnings())
        warnings.simplefilter("ignore")
        yield


# model infos
# MODEL_NAME is the model identifier that main() passes to load_model().
# NOTE(review): it is assigned twice in a row; the second value
# ("all-mpnet-base-v2") is the one that takes effect.
MODEL_NAME = "answerdotai/ModernBERT-base"
MODEL_NAME = "all-mpnet-base-v2"
# Emoji prefix used in the console messages and panel titles.
SUIK_LOGO = "🦦"


# agent config
# Path that main() hands to load_meta().
META_FILE = "meta.json"
# Shared rich console used for all output.
console = Console()
# Style applied to the field labels ("Command: ", "Description: ", ...) in format_response().
LABEL_COLORS = Style(color="#B7D46F", bold=True)


def load_model(model_name):
    """Instantiate a ``SentenceTransformer`` for ``model_name`` with output silenced.

    Args:
        model_name: Identifier passed straight to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    # Construction happens inside suppress_outputs() so that anything the
    # library prints or warns about while loading stays off the console.
    with suppress_outputs():
        return SentenceTransformer(model_name)


def load_meta(meta_file):
    """Return the command documents stored in the database.

    Args:
        meta_file: Path of the JSON metadata file. It is not read any more;
            the parameter is kept so existing callers keep working.

    Returns:
        The result of ``_db.fetch_all_documents()``.
    """
    # The JSON-file read that followed this return could never execute
    # (it came after an unconditional return), so it has been removed.
    return _db.fetch_all_documents()


def embed(items, model):
    """Encode ``items`` with ``model``.

    Args:
        items: Value handed unchanged to ``model.encode``.
        model: Any object exposing an ``encode`` method.

    Returns:
        Whatever ``model.encode(items)`` returns.
    """
    embeddings = model.encode(items)
    return embeddings



def load_meta_emb(meta, model):
    """Embed the ``description`` field of every document in ``meta``.

    Args:
        meta: Iterable of documents, each with a ``"description"`` key.
        model: Encoder passed on to ``embed``.

    Returns:
        The embeddings produced by ``embed`` for the descriptions, in the
        same order as ``meta``.
    """
    descriptions = [doc["description"] for doc in meta]
    return embed(descriptions, model)


def get_meta_match(meta_emb, q_emb, model):
    """Find the document embedding most similar to the query embedding.

    Args:
        meta_emb: Embeddings of all documents.
        q_emb: Embedding of the question.
        model: Unused; accepted for interface compatibility.

    Returns:
        ``(index, score)`` of the single top entry of
        ``util.pytorch_cos_sim(q_emb, meta_emb)``, both as plain Python numbers.
    """
    similarities = util.pytorch_cos_sim(q_emb, meta_emb)
    best = torch.topk(similarities, k=1)
    # topk keeps the leading dimension, so the single hit sits at [0][0].
    best_idx = best.indices[0][0].item()
    best_score = best.values[0][0].item()
    return best_idx, best_score


def ask(question, model, meta, meta_emb):
    """Return the document that best matches ``question`` and its score.

    Args:
        question: Text of the user's question.
        model: Encoder used to embed the question.
        meta: Sequence of documents, index-aligned with ``meta_emb``.
        meta_emb: Embeddings of the documents in ``meta``.

    Returns:
        ``(document, score)`` for the best match reported by ``get_meta_match``.
    """
    best_idx, best_score = get_meta_match(meta_emb, embed(question, model), model)
    return meta[best_idx], best_score


def format_response(match: dict, score: float) -> Panel:
    """Render a matched command document as a rich ``Panel``.

    Args:
        match: Command document. The ``name``, ``description`` and
            ``examples`` keys are read; all of them are optional.
        score: Score of the match, shown with two decimals in the footer.

    Returns:
        A panel whose body lists the command name, the description (when
        present) and one sub-panel per example (when present), and whose
        subtitle shows the model name and the match score.
    """
    value_style = Style(color="white")
    table = Table(show_header=False, box=None, padding=(0, 1), collapse_padding=True)

    # add command name
    table.add_row(Text.assemble(("Command: ", LABEL_COLORS), (match.get("name", "N/A"), value_style)))

    # add description
    if description := match.get("description"):
        table.add_row(Text.assemble(("Description: ", LABEL_COLORS), (description, value_style)))

    # add examples, each highlighted as bash inside its own numbered panel
    if examples := match.get("examples"):
        example_panels = [
            Panel(
                Syntax(
                    example,
                    "bash",
                    theme="monokai",
                    line_numbers=False,
                    word_wrap=True,
                    padding=(0, 2),
                ),
                title=f"Example {idx}",
                border_style="cyan",
            )
            for idx, example in enumerate(examples, start=1)
        ]
        table.add_row(Text("Examples:", style=LABEL_COLORS))
        # combine all examples into a group so they occupy a single table row
        table.add_row(Group(*example_panels))

    # create a footer with model name and match score
    footer = Text.assemble(
        ("Model: ", LABEL_COLORS),
        (MODEL_NAME, Style(color="white", italic=True)),
        (" | Match Score: ", Style(dim=True)),
        (f"{score:.2f}", Style(color="white")),
    )

    return Panel(
        table,
        title=f"{SUIK_LOGO} Match",
        subtitle=footer,
        border_style="white",
        padding=(0, 1),
    )


def main():
    """Run the interactive question/answer loop.

    Loads the model and the command documents, embeds the documents once,
    then repeatedly prompts for a question and prints the best match. The
    loop ends when the user types ``exit`` or the prompt is interrupted
    (Ctrl-C) or reaches end of input (Ctrl-D).
    """
    console.print(Panel.fit(f"{SUIK_LOGO} Suika Loading ...", title="Initializing"))
    model = load_model(MODEL_NAME)
    meta = load_meta(META_FILE)
    meta_emb = load_meta_emb(meta, model)

    console.print(Panel.fit("✨ Ask me any linux commands (type 'exit' to quit)", title=f"{SUIK_LOGO} Suika Ready"))
    while True:
        try:
            question = Prompt.ask(f"Your question {SUIK_LOGO} ").strip()
        except (EOFError, KeyboardInterrupt):
            # Leaving the prompt without an answer ends the session the same
            # way as typing "exit", instead of surfacing a traceback.
            question = "exit"

        if question.lower() == "exit":
            console.print(f"{SUIK_LOGO} 👋 Goodbye!")
            break
        if not question:
            # Nothing to search for; ask again.
            continue

        try:
            match, score = ask(question, model, meta, meta_emb)
            response_panel = format_response(match, score)
            console.print(response_panel)
            feedback = typer.confirm("\n🦦 Was this response helpful?")
            if not feedback:
                console.print("[yellow]🦦 I'm sorry the response wasn't helpful. Please try rephrasing your question.[/yellow]")
        except Exception as e:
            # Top-level boundary of the loop: report the failure and keep the
            # session alive so the user can try another question.
            console.print(f"[red]{SUIK_LOGO} An error occurred: {str(e)}[/red]")
            console.print(f"{SUIK_LOGO} Please try again with a different question.")


if __name__ == "__main__":
    # Hand main() to typer exactly once when the file is run as a script.
    typer.run(main)
86 changes: 86 additions & 0 deletions db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import sqlite3
from copy import deepcopy
import json

# File name handed to sqlite3.connect() by every function in this module.
DATABASE_NAME = "suika_commands.db"
# Name of the table that stores the command documents.
TABLE = "suika_commands"


def create_database():
    """Create the commands table if it does not exist yet.

    Connects to ``DATABASE_NAME``, issues ``CREATE TABLE IF NOT EXISTS`` for
    ``TABLE`` and commits. The connection is closed even if a statement fails.
    """
    connection = sqlite3.connect(DATABASE_NAME)
    try:
        ddl = (
            "\n"
            f"CREATE TABLE IF NOT EXISTS {TABLE} (\n"
            "linux_cmd_id TEXT PRIMARY KEY,\n"
            "name TEXT NOT NULL,\n"
            "description TEXT NOT NULL,\n"
            "syntax TEXT NOT NULL,\n"
            "keywords TEXT NOT NULL,\n"
            "examples TEXT NOT NULL\n"
            ")\n"
        )
        connection.cursor().execute(ddl)
        connection.commit()
    finally:
        connection.close()


def insert_doc(doc, cursor):
    """Insert one command document into the commands table.

    Documents without a truthy ``name`` are skipped silently. A row that
    violates an integrity constraint is reported on stdout and skipped.

    Args:
        doc: Mapping with ``name``, ``description``, ``syntax``, ``keywords``
            and ``examples`` entries. It is left unmodified.
        cursor: Open ``sqlite3`` cursor used for the insert. Nothing is
            committed here; that is left to the caller.

    Raises:
        KeyError: If ``keywords`` or ``examples`` is missing from ``doc``.
    """
    if not doc.get("name"):
        return

    # Build the bound parameters in a fresh dict so the caller's document
    # is not rewritten as a side effect.
    row = {
        **doc,
        "linux_cmd_id": f"linux_{doc['name']}",
        # convert list into json string
        "keywords": json.dumps(doc["keywords"]),
        "examples": json.dumps(doc["examples"]),
    }
    statement = (
        "\n"
        f"INSERT INTO {TABLE} (name, linux_cmd_id, description, syntax, keywords, examples)\n"
        "VALUES (:name, :linux_cmd_id, :description, :syntax, :keywords, :examples)\n"
    )
    try:
        cursor.execute(statement, row)
    except sqlite3.IntegrityError:
        print(f"integrityError: skipping document {row['linux_cmd_id']} due to unique constraint violation.")


def load_data(docs):
    """Insert every document in ``docs`` and commit them in one transaction.

    An ``sqlite3.OperationalError`` (from connecting, inserting or
    committing) is reported on stdout rather than raised.

    Args:
        docs: Iterable of command documents accepted by ``insert_doc``.
    """
    # Bound before the try so the finally clause can tell whether connect()
    # succeeded; otherwise a failed connect would surface as an
    # UnboundLocalError instead of the reported OperationalError.
    conn = None
    try:
        conn = sqlite3.connect(DATABASE_NAME)
        cursor = conn.cursor()
        for doc in docs:
            # Work on a copy so the caller's documents are never altered.
            insert_doc(deepcopy(doc), cursor)
        conn.commit()
    except sqlite3.OperationalError as e:
        print(f"operationalError: {e}")
    finally:
        if conn is not None:
            conn.close()


def query_by_id(doc_id):
    """Fetch the row whose ``linux_cmd_id`` equals ``doc_id``.

    Args:
        doc_id: Value matched against the ``linux_cmd_id`` column.

    Returns:
        The first matching row as returned by ``cursor.fetchone()``, or
        ``None`` when no row matches.
    """
    conn = sqlite3.connect(DATABASE_NAME)
    try:
        cursor = conn.cursor()
        # The table name comes from the module constant, like the other
        # queries in this module; the id itself is bound as a parameter.
        cursor.execute(f"SELECT * FROM {TABLE} WHERE linux_cmd_id = ?", (doc_id,))
        return cursor.fetchone()
    finally:
        # Close the connection even when the query raises.
        conn.close()


def fetch_all_documents():
    """Return every row of the commands table as a dict.

    The ``keywords`` and ``examples`` columns are decoded from their JSON
    string form. An ``sqlite3.OperationalError`` is reported on stdout and
    results in an empty list.

    Returns:
        A list with one dict per row, keyed by column name.
    """
    # Bound before the try so the finally clause can tell whether connect()
    # succeeded; otherwise a failed connect would surface as an
    # UnboundLocalError instead of the reported OperationalError.
    conn = None
    try:
        conn = sqlite3.connect(DATABASE_NAME)
        # sqlite3.Row lets each row be converted with dict(row).
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        cursor.execute(f"SELECT * FROM {TABLE}")
        documents = [dict(row) for row in cursor.fetchall()]
        for doc in documents:
            doc["keywords"] = json.loads(doc["keywords"])
            doc["examples"] = json.loads(doc["examples"])
    except sqlite3.OperationalError as e:
        print(f"OperationalError: {e}")
        documents = []
    finally:
        if conn is not None:
            conn.close()
    return documents
Loading

0 comments on commit cd9fe2d

Please sign in to comment.