
Commit

vector db setup added
arun477 committed Dec 24, 2024
1 parent 82281b2 commit f17aebd
Showing 8 changed files with 668 additions and 30 deletions.
17 changes: 12 additions & 5 deletions agent.py
@@ -14,6 +14,7 @@
import db as _db
import numpy as np
from tqdm import tqdm
import vec_db as _vec_db

from rich.panel import Panel
from rich.table import Table
@@ -47,7 +48,7 @@ def suppress_outputs():
# model infos
MODEL_NAME = "answerdotai/ModernBERT-large"
# MODEL_NAME = "answerdotai/ModernBERT-base"
# MODEL_NAME = "all-mpnet-base-v2"
MODEL_NAME = "all-mpnet-base-v2"
SUIK_LOGO = "🦦"


@@ -65,7 +66,7 @@ def load_model(model_name):


def load_meta(meta_file):
# return _db.fetch_all_documents()
return _db.fetch_all_documents()
with open(meta_file, "r") as f:
return json.loads(f.read())

@@ -130,19 +131,25 @@ def format_response(match: dict, score: float) -> Panel:
)


def vector_db_ask(query, model):
query_emb = embed(query, model)
return _vec_db.search(query_emb)


def main():
console.print(Panel.fit(f"{SUIK_LOGO} Suika Loading ...", title="Initializing"))
model = load_model(MODEL_NAME)
meta = load_meta(META_FILE)
meta_emb = load_meta_emb(meta, model)
# meta = load_meta(META_FILE)
# meta_emb = load_meta_emb(meta, model)
console.print(Panel.fit("✨ Ask me any linux commands (type 'exit' to quit)", title=f"{SUIK_LOGO} Suika Ready"))
while True:
question = Prompt.ask(f"[#B7D46F]❯ Your question {SUIK_LOGO} ")
if question.lower() == "exit":
console.print(f"{SUIK_LOGO} 👋 Goodbye!")
break
try:
match, score = ask(question, model, meta, meta_emb)
# match, score = ask(question, model, meta, meta_emb)
match, score = vector_db_ask(question, model)
response_panel = format_response(match, score)
console.print(response_panel)
except Exception as e:
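The hunks above swap the in-memory scan over meta and meta_emb for a lookup through the new vec_db module. A rough one-shot illustration of that path follows; it is not part of this commit, the sample query is invented, and model.encode() stands in for agent's embed() helper, which is not shown in this diff:

# Hypothetical walk through the new vector-DB path; not part of the committed code.
from sentence_transformers import SentenceTransformer
import vec_db as _vec_db

model = SentenceTransformer("all-mpnet-base-v2")       # model now selected in agent.py
query_emb = model.encode("find large files on disk")   # stands in for embed(query, model)
match, score = _vec_db.search(query_emb)               # same call vector_db_ask() makes
print(match, score)                                    # a document dict plus a score, as format_response() expects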
23 changes: 0 additions & 23 deletions banner.txt

This file was deleted.

27 changes: 25 additions & 2 deletions db.py
@@ -4,6 +4,7 @@

DATABASE_NAME = "suika_commands.db"
TABLE = "suika_commands"
VEC_ID_MAPPING_TABLE = "id_to_vec_pos"


def create_database():
@@ -59,12 +60,17 @@ def load_data(docs):


def query_by_id(doc_id):
print("doc_id", doc_id)
conn = sqlite3.connect(DATABASE_NAME)
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
cursor.execute("SELECT * FROM suika_commands WHERE linux_cmd_id = ?", (doc_id,))
res = cursor.fetchone()
doc = dict(res)
conn.close()
return res
doc["keywords"] = json.loads(doc["keywords"])
doc["examples"] = json.loads(doc["examples"])
return doc


def fetch_all_documents():
@@ -83,4 +89,21 @@ def fetch_all_documents():
documents = []
finally:
conn.close()
return documents
return documents


def fetch_vec_id_mapping():
try:
conn = sqlite3.connect(DATABASE_NAME)
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
cursor.execute(f"SELECT * FROM {VEC_ID_MAPPING_TABLE}")
rows = cursor.fetchall()
documents = [dict(row) for row in rows]
mapping = {ele["faiss_index_id"]: ele["linux_cmd_id"] for ele in documents}
except sqlite3.OperationalError as e:
print(f"OperationalError: {e}")
mapping = {}
finally:
conn.close()
return mapping
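
The vec_db module that agent.py now imports is among the files whose diff is not shown on this page. Judging from how it is called above, where _vec_db.search(query_emb) returns a (match, score) pair, and from the faiss_index_id -> linux_cmd_id mapping added here, a minimal sketch, assuming the new suika_commands_vector.index binary is a FAISS index, could look like this:

# vec_db.py -- illustrative sketch only; the committed implementation is not shown in this diff.
import faiss                      # assumed index library, suggested by the faiss_index_id column
import numpy as np

import db as _db

INDEX_FILE = "suika_commands_vector.index"   # binary file added by this commit


def search(query_emb, top_k=1):
    # Return (document dict, score) for the nearest stored command.
    index = faiss.read_index(INDEX_FILE)
    query = np.asarray(query_emb, dtype="float32").reshape(1, -1)
    scores, positions = index.search(query, top_k)

    # FAISS reports positions within the index; translate them back to
    # linux_cmd_id values via the id_to_vec_pos table maintained in db.py.
    mapping = _db.fetch_vec_id_mapping()
    doc_id = mapping[int(positions[0][0])]
    doc = _db.query_by_id(doc_id)
    return doc, float(scores[0][0])

Keeping the position-to-id mapping in SQLite rather than inside the index is what lets query_by_id() return the full command record, with keywords and examples already JSON-decoded, in a single lookup.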
264 changes: 264 additions & 0 deletions model_eval.ipynb
@@ -0,0 +1,264 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"id": "03c08f67-c3e0-4ab2-a517-e9215508b9ec",
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"from sentence_transformers.util import pytorch_cos_sim\n",
"from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, InformationRetrievalEvaluator, NanoBEIREvaluator\n",
"import torch\n",
"import datasets\n",
"import random\n",
"torch._dynamo.config.suppress_errors = True\n",
"\n",
"MODEL = 'answerdotai/ModernBERT-base'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2d8b5fa4-8579-47ef-82b5-ba5f4584a658",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No sentence-transformers model found with name answerdotai/ModernBERT-base. Creating a new one with mean pooling.\n"
]
}
],
"source": [
"model = SentenceTransformer(MODEL)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "33f037f9-3731-4f27-bdd8-69e59ed3b8d0",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'pearson_cosine': 0.5190741868883064, 'spearman_cosine': 0.5566359148742774}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# semantic score\n",
"data = datasets.load_dataset(\"sentence-transformers/stsb\", split='validation')\n",
"evaluator = EmbeddingSimilarityEvaluator(sentences1=data['sentence1'], sentences2=data['sentence2'], scores=data['score'])\n",
"result = evaluator(model)\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "f31e5b5a-f26c-4813-ba7b-5b088b35419e",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# information reterival eval\n",
"datasets = [\"QuoraRetrieval\", \"MSMARCO\"]\n",
"query_prompts = {\n",
" \"QuoraRetrieval\": \"Instruct: Given a question, retrieve questions that are semantically equivalent to the given question\\nQuery: \",\n",
" \"MSMARCO\": \"Instruct: Given a web search query, retrieve relevant passages that answer the query\\nQuery: \"\n",
"}\n",
"ir_evaluator = NanoBEIREvaluator(dataset_names=datasets, query_prompts=query_prompts)\n",
"result = ir_evaluator(model)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "67d05518-9ad6-4cdd-822c-7f4d2cbc11ca",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'NanoQuoraRetrieval_cosine_accuracy@1': 0.04,\n",
" 'NanoQuoraRetrieval_cosine_accuracy@3': 0.12,\n",
" 'NanoQuoraRetrieval_cosine_accuracy@5': 0.12,\n",
" 'NanoQuoraRetrieval_cosine_accuracy@10': 0.16,\n",
" 'NanoQuoraRetrieval_cosine_precision@1': 0.04,\n",
" 'NanoQuoraRetrieval_cosine_precision@3': 0.04,\n",
" 'NanoQuoraRetrieval_cosine_precision@5': 0.024000000000000004,\n",
" 'NanoQuoraRetrieval_cosine_precision@10': 0.016,\n",
" 'NanoQuoraRetrieval_cosine_recall@1': 0.04,\n",
" 'NanoQuoraRetrieval_cosine_recall@3': 0.12,\n",
" 'NanoQuoraRetrieval_cosine_recall@5': 0.12,\n",
" 'NanoQuoraRetrieval_cosine_recall@10': 0.15,\n",
" 'NanoQuoraRetrieval_cosine_ndcg@10': 0.09591463641493617,\n",
" 'NanoQuoraRetrieval_cosine_mrr@10': 0.07916666666666666,\n",
" 'NanoQuoraRetrieval_cosine_map@100': 0.08583165716368223,\n",
" 'NanoMSMARCO_cosine_accuracy@1': 0.0,\n",
" 'NanoMSMARCO_cosine_accuracy@3': 0.0,\n",
" 'NanoMSMARCO_cosine_accuracy@5': 0.0,\n",
" 'NanoMSMARCO_cosine_accuracy@10': 0.0,\n",
" 'NanoMSMARCO_cosine_precision@1': 0.0,\n",
" 'NanoMSMARCO_cosine_precision@3': 0.0,\n",
" 'NanoMSMARCO_cosine_precision@5': 0.0,\n",
" 'NanoMSMARCO_cosine_precision@10': 0.0,\n",
" 'NanoMSMARCO_cosine_recall@1': 0.0,\n",
" 'NanoMSMARCO_cosine_recall@3': 0.0,\n",
" 'NanoMSMARCO_cosine_recall@5': 0.0,\n",
" 'NanoMSMARCO_cosine_recall@10': 0.0,\n",
" 'NanoMSMARCO_cosine_ndcg@10': 0.0,\n",
" 'NanoMSMARCO_cosine_mrr@10': 0.0,\n",
" 'NanoMSMARCO_cosine_map@100': 0.0,\n",
" 'NanoBEIR_mean_cosine_accuracy@1': 0.02,\n",
" 'NanoBEIR_mean_cosine_accuracy@3': 0.06,\n",
" 'NanoBEIR_mean_cosine_accuracy@5': 0.06,\n",
" 'NanoBEIR_mean_cosine_accuracy@10': 0.08,\n",
" 'NanoBEIR_mean_cosine_precision@1': 0.02,\n",
" 'NanoBEIR_mean_cosine_precision@3': 0.02,\n",
" 'NanoBEIR_mean_cosine_precision@5': 0.012000000000000002,\n",
" 'NanoBEIR_mean_cosine_precision@10': 0.008,\n",
" 'NanoBEIR_mean_cosine_recall@1': 0.02,\n",
" 'NanoBEIR_mean_cosine_recall@3': 0.06,\n",
" 'NanoBEIR_mean_cosine_recall@5': 0.06,\n",
" 'NanoBEIR_mean_cosine_recall@10': 0.075,\n",
" 'NanoBEIR_mean_cosine_ndcg@10': 0.04795731820746808,\n",
" 'NanoBEIR_mean_cosine_mrr@10': 0.03958333333333333,\n",
" 'NanoBEIR_mean_cosine_map@100': 0.04291582858184111}"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8ef35f0c-1616-4aa2-b415-08af447df206",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7055dc7-30e7-48c1-8653-1639416d48eb",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c71bc2c-1c93-47f8-8567-7bdc948e238d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "61885a1b-0ce9-4c90-9fb5-edb1ec2d659a",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "eff31b35-6695-455f-92fd-eb702932d7e5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "31292d8a-365e-4a12-afca-f19a639fff0d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef7f4717-e0ca-4109-a3cf-5fdb6f8fa10f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "c32a9e77-2ee5-469f-af74-7c36d857bf67",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca893ddf-03f8-4fa1-a3ba-089325d9b942",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1d00003-c5ab-4be5-a773-c585c5e2a8e9",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d73ef62a-3e83-4e8f-bbc5-b684853f634d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "m",
"language": "python",
"name": "m"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
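
The notebook scores a raw answerdotai/ModernBERT-base encoder that sentence-transformers wraps with mean pooling and no retrieval fine-tuning, which is consistent with the low STS-B and NanoBEIR numbers above. For comparison, the same evaluators could be pointed at the all-mpnet-base-v2 checkpoint that agent.py now uses; a minimal sketch, outside the committed notebook:

# Run the notebook's evaluators against all-mpnet-base-v2 for comparison (not part of this commit).
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, NanoBEIREvaluator
import datasets

model = SentenceTransformer("all-mpnet-base-v2")

# Semantic similarity on the STS-B validation split, as in the notebook.
stsb = datasets.load_dataset("sentence-transformers/stsb", split="validation")
sts_eval = EmbeddingSimilarityEvaluator(
    sentences1=stsb["sentence1"], sentences2=stsb["sentence2"], scores=stsb["score"]
)
print(sts_eval(model))

# Retrieval on the same NanoBEIR subsets used in the notebook.
ir_eval = NanoBEIREvaluator(dataset_names=["QuoraRetrieval", "MSMARCO"])
print(ir_eval(model))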
Binary file modified suika_commands.db
Binary file added suika_commands_vector.index
Diffs for the remaining changed files are not shown.