Skip to content

Commit

Permalink
make myvanna file to clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
aran-nakayama committed Oct 3, 2024
1 parent 0d63e06 commit ffae625
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 11 deletions.
3 changes: 1 addition & 2 deletions examples/use-rdb-resource/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from openssa import DANA, DbResource
from dotenv import load_dotenv

# .envファイルの読み込み
load_dotenv()

def get_or_create_agent(query) -> DANA:
Expand All @@ -19,7 +18,7 @@ def solve(question, query) -> str:

if __name__ == '__main__':
QUESTION = (
"What is the best-selling product in the last year from sales_data table?"
"Can you list the products in order of sales volume from highest to lowest?"
)

query = generate_sql_from_prompt(QUESTION)
Expand Down
4 changes: 0 additions & 4 deletions examples/use-rdb-resource/make_example_table_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
from myvanna import train_vanna_for_sales_data

# .envファイルの読み込み
load_dotenv()

Base = declarative_base()
Expand Down Expand Up @@ -44,7 +43,6 @@ def drop_table(self, table_class):
if inspector.has_table(table_class.__tablename__):
table_class.__table__.drop(self.engine)

# データ生成
fake = Faker()
seed_value = 42
random.seed(seed_value)
Expand Down Expand Up @@ -85,7 +83,6 @@ def generate_sales_data(session, num_records):
session = db.get_session()

generate_sales_data(session, 20000)
print("20000件のデータがsales_dataテーブルに作成されました。")

train_vanna_for_sales_data("""
CREATE TABLE sales_data (
Expand All @@ -96,4 +93,3 @@ def generate_sales_data(session, num_records):
region VARCHAR(255)
)
""")
print("vannaをsales_dataに合わせて訓練しました。")
7 changes: 2 additions & 5 deletions examples/use-rdb-resource/myvanna.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,26 @@
from vanna.openai import OpenAI_Chat
from vanna.chromadb import ChromaDB_VectorStore

# .envファイルの読み込み
load_dotenv()

# 環境変数から接続情報を取得
db_user = os.getenv('DB_USERNAME')
db_password = os.getenv('DB_PASSWORD')
db_host = os.getenv('DB_HOST')
db_port = int(os.getenv('DB_PORT'))
db_database = os.getenv('DB_NAME')
openai_api_key = os.getenv('OPENAI_API_KEY')

# MyVannaクラス定義
class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)

# sales_dataに基づいてVannaを訓練する関数

def train_vanna_for_sales_data(ddl):
vn_openai = MyVanna(config={'model': 'gpt-4o', 'api_key': openai_api_key})
vn_openai.train(ddl=ddl)

# プロンプトからSQLを生成する関数

def generate_sql_from_prompt(question) -> str:
vn_openai = MyVanna(config={'model': 'gpt-4o', 'api_key': openai_api_key})
vn_openai.connect_to_mysql(host=db_host, dbname=db_database, user=db_user, password=db_password, port=db_port)
Expand Down

0 comments on commit ffae625

Please sign in to comment.