-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
24,737 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
.ipynb_checkpoints | ||
__pycache__ | ||
tldr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,113 +1,159 @@ | ||
import typer | ||
from sentence_transformers import SentenceTransformer | ||
from sentence_transformers import SentenceTransformer, util | ||
import os | ||
import torch | ||
import torch._dynamo | ||
import json | ||
from rich import print | ||
from rich.console import Console | ||
from rich.panel import Panel | ||
from rich.table import Table | ||
from rich.markdown import Markdown | ||
import warnings | ||
import logging | ||
import contextlib | ||
from PIL import Image | ||
import db as _db | ||
|
||
from rich.panel import Panel | ||
from rich.table import Table | ||
from rich.markdown import Markdown | ||
from rich.console import Group | ||
from rich.syntax import Syntax | ||
from rich.style import Style | ||
from rich.text import Text | ||
from rich.prompt import Prompt | ||
|
||
|
||
warnings.filterwarnings('ignore') | ||
# supress warnings | ||
warnings.filterwarnings("ignore") | ||
logging.getLogger().setLevel(logging.CRITICAL) | ||
torch._dynamo.config.suppress_errors = True | ||
logging.getLogger("torch").setLevel(logging.CRITICAL) | ||
logging.getLogger("torch._dynamo").setLevel(logging.CRITICAL) | ||
logging.getLogger("triton").setLevel(logging.CRITICAL) | ||
|
||
MODEL_NAME = "answerdotai/ModernBERT-base" | ||
META_FILE = 'meta.json' | ||
console = Console() | ||
|
||
@contextlib.contextmanager | ||
def suppress_outputs(): | ||
"""Context manager to suppress all outputs, warnings, and errors temporarily""" | ||
with open(os.devnull, 'w') as devnull: | ||
with open(os.devnull, "w") as devnull: | ||
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull): | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore") | ||
yield | ||
|
||
|
||
# model infos | ||
MODEL_NAME = "answerdotai/ModernBERT-base" | ||
MODEL_NAME = "all-mpnet-base-v2" | ||
SUIK_LOGO = "🦦" | ||
|
||
|
||
# agent config | ||
META_FILE = "meta.json" | ||
console = Console() | ||
LABEL_COLORS = Style(color="#B7D46F", bold=True) | ||
|
||
|
||
def load_model(model_name): | ||
with suppress_outputs(): | ||
model = SentenceTransformer(model_name) | ||
return model | ||
|
||
|
||
def load_meta(meta_file): | ||
with open(meta_file, 'r') as f: | ||
return _db.fetch_all_documents() | ||
with open(meta_file, "r") as f: | ||
return json.loads(f.read()) | ||
|
||
|
||
def embed(items, model): | ||
return model.encode(items) | ||
|
||
|
||
|
||
def load_meta_emb(meta, model): | ||
return embed([doc['description'] for doc in meta], model) | ||
return embed([doc["description"] for doc in meta], model) | ||
|
||
|
||
def get_meta_match(meta_emb, q_emb, model): | ||
match = torch.topk(util.pytorch_cos_sim(q_emb, meta_emb), k=1) | ||
match_idx, score = match.indices[0][0].item(), match.values[0][0].item() | ||
return match_idx, score | ||
|
||
|
||
def ask(question, model, meta, meta_emb): | ||
question_emb = embed(question, model) | ||
match_idx, score = get_meta_match(meta_emb, question_emb, model) | ||
return meta[match_idx], score | ||
|
||
|
||
def format_response(match: dict, score: float) -> Panel: | ||
table = Table(show_header=False, box=None, padding=(0, 2)) | ||
table.add_row( | ||
"[bold blue]Command:[/bold blue]", | ||
match.get('name', 'N/A') | ||
table = Table(show_header=False, box=None, padding=(0, 1), collapse_padding=True) | ||
|
||
# add command name | ||
cmd_text = Text.assemble(("Command: ", LABEL_COLORS), (match.get("name", "N/A"), Style(color="white"))) | ||
table.add_row(cmd_text) | ||
|
||
# add description | ||
if description := match.get("description"): | ||
desc_text = Text.assemble(("Description: ", LABEL_COLORS), (description, Style(color="white"))) | ||
table.add_row(desc_text) | ||
|
||
# add examples | ||
if examples := match.get("examples"): | ||
examples_group = [] | ||
for idx, example in enumerate(examples, start=1): | ||
code = Syntax( | ||
example, | ||
"bash", | ||
theme="monokai", | ||
line_numbers=False, | ||
word_wrap=True, | ||
padding=(0, 2), | ||
) | ||
examples_group.append(Panel(code, title=f"Example {idx}", border_style="cyan")) | ||
|
||
# combine all examples into a group for display | ||
examples_panel = Group(*examples_group) | ||
table.add_row(Text("Examples:", style=LABEL_COLORS)) | ||
table.add_row(examples_panel) | ||
|
||
# create a footer with model name and match score | ||
footer = Text.assemble( | ||
("Model: ", LABEL_COLORS), | ||
(MODEL_NAME, Style(color="white", italic=True)), | ||
(" | Match Score: ", Style(dim=True)), | ||
(f"{score:.2f}", Style(color="white")), | ||
) | ||
if 'examples' in match and match['examples']: | ||
examples_md = "\n".join(f"```bash\n{example}\n```" for example in match['examples']) | ||
table.add_row( | ||
"[bold blue]Examples:[/bold blue]", | ||
"" | ||
) | ||
table.add_row( | ||
"", | ||
Markdown(examples_md) | ||
) | ||
|
||
return Panel( | ||
table, | ||
title=f"🦦 [bold]Match found (confidence: {score:.2f})[/bold]", | ||
border_style="green" | ||
title=f"{SUIK_LOGO} Match", | ||
subtitle=footer, | ||
border_style="white", | ||
padding=(0, 1), | ||
) | ||
|
||
|
||
def main(): | ||
console.print(Panel.fit("🦦 Suika Loading ...", title="Initializing")) | ||
console.print(Panel.fit(f"{SUIK_LOGO} Suika Loading ...", title="Initializing")) | ||
model = load_model(MODEL_NAME) | ||
meta = load_meta(META_FILE) | ||
meta_emb = load_meta_emb(meta, model) | ||
|
||
console.print(Panel.fit("✨ Ask me any linux commands (type 'exit' to quit)", title="🦦 Suika Ready")) | ||
console.print(Panel.fit("✨ Ask me any linux commands (type 'exit' to quit)", title=f"{SUIK_LOGO} Suika Ready")) | ||
while True: | ||
question = typer.prompt("\nYour question") | ||
if question.lower() == 'exit': | ||
console.print("\n🦦 👋 Goodbye!") | ||
break | ||
question = Prompt.ask(f"Your question {SUIK_LOGO} ") | ||
# question = typer.prompt("Your question") | ||
if question.lower() == "exit": | ||
console.print(f"{SUIK_LOGO} 👋 Goodbye!") | ||
break | ||
try: | ||
match, score = ask(question, model, meta, meta_emb) | ||
response_panel = format_response(match, score) | ||
console.print(response_panel) | ||
feedback = typer.confirm("\n🦦 Was this response helpful?") | ||
if not feedback: | ||
console.print("[yellow]🦦 I'm sorry the response wasn't helpful. Please try rephrasing your question.[/yellow]") | ||
except Exception as e: | ||
console.print(f"[red]🦦 An error occurred: {str(e)}[/red]") | ||
console.print("🦦 Please try again with a different question.") | ||
console.print(f"[red]{SUIK_LOGO} An error occurred: {str(e)}[/red]") | ||
console.print(f"{SUIK_LOGO} Please try again with a different question.") | ||
|
||
|
||
if __name__ == "__main__": | ||
typer.run(main) | ||
typer.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import sqlite3 | ||
from copy import deepcopy | ||
import json | ||
|
||
DATABASE_NAME = "suika_commands.db" | ||
TABLE = "suika_commands" | ||
|
||
|
||
def create_database(): | ||
conn = sqlite3.connect(DATABASE_NAME) | ||
try: | ||
cursor = conn.cursor() | ||
cursor.execute(f""" | ||
CREATE TABLE IF NOT EXISTS {TABLE} ( | ||
linux_cmd_id TEXT PRIMARY KEY, | ||
name TEXT NOT NULL, | ||
description TEXT NOT NULL, | ||
syntax TEXT NOT NULL, | ||
keywords TEXT NOT NULL, | ||
examples TEXT NOT NULL | ||
) | ||
""") | ||
conn.commit() | ||
finally: | ||
conn.close() | ||
|
||
|
||
def insert_doc(doc, cursor): | ||
if not doc.get("name"): | ||
return | ||
# convert list into json string | ||
doc["linux_cmd_id"] = f"linux_{doc['name']}" | ||
doc["keywords"] = json.dumps(doc["keywords"]) | ||
doc["examples"] = json.dumps(doc["examples"]) | ||
try: | ||
cursor.execute( | ||
f""" | ||
INSERT INTO {TABLE} (name, linux_cmd_id, description, syntax, keywords, examples) | ||
VALUES (:name, :linux_cmd_id, :description, :syntax, :keywords, :examples) | ||
""", | ||
doc, | ||
) | ||
except sqlite3.IntegrityError as e: | ||
print(f"integrityError: skipping document {doc['linux_cmd_id']} due to unique constraint violation.") | ||
|
||
|
||
def load_data(docs): | ||
try: | ||
conn = sqlite3.connect(DATABASE_NAME) | ||
cursor = conn.cursor() | ||
for doc in docs: | ||
doc = deepcopy(doc) | ||
insert_doc(doc, cursor) | ||
conn.commit() | ||
except sqlite3.OperationalError as e: | ||
print(f"operationalError: {e}") | ||
finally: | ||
conn.close() | ||
|
||
|
||
def query_by_id(doc_id): | ||
conn = sqlite3.connect(DATABASE_NAME) | ||
cursor = conn.cursor() | ||
cursor.execute("SELECT * FROM suika_commands WHERE linux_cmd_id = ?", (doc_id,)) | ||
res = cursor.fetchone() | ||
conn.close() | ||
return res | ||
|
||
|
||
def fetch_all_documents(): | ||
try: | ||
conn = sqlite3.connect(DATABASE_NAME) | ||
conn.row_factory = sqlite3.Row | ||
cursor = conn.cursor() | ||
cursor.execute(f"SELECT * FROM {TABLE}") | ||
rows = cursor.fetchall() | ||
documents = [dict(row) for row in rows] | ||
for doc in documents: | ||
doc["keywords"] = json.loads(doc["keywords"]) | ||
doc["examples"] = json.loads(doc["examples"]) | ||
except sqlite3.OperationalError as e: | ||
print(f"OperationalError: {e}") | ||
documents = [] | ||
finally: | ||
conn.close() | ||
return documents |
Oops, something went wrong.