diff --git a/examples/use-rdb-resource/main.py b/examples/use-rdb-resource/main.py index 26ba635a7..272dcc2b5 100644 --- a/examples/use-rdb-resource/main.py +++ b/examples/use-rdb-resource/main.py @@ -2,7 +2,6 @@ from openssa import DANA, DbResource from dotenv import load_dotenv -# .envファイルの読み込み load_dotenv() def get_or_create_agent(query) -> DANA: @@ -19,7 +18,7 @@ def solve(question, query) -> str: if __name__ == '__main__': QUESTION = ( - "What is the best-selling product in the last year from sales_data table?" + "Can you list the products in order of sales volume from highest to lowest?" ) query = generate_sql_from_prompt(QUESTION) diff --git a/examples/use-rdb-resource/make_example_table_data.py b/examples/use-rdb-resource/make_example_table_data.py index 224f5ad5f..f5faa1bc3 100644 --- a/examples/use-rdb-resource/make_example_table_data.py +++ b/examples/use-rdb-resource/make_example_table_data.py @@ -6,7 +6,6 @@ import os from myvanna import train_vanna_for_sales_data -# .envファイルの読み込み load_dotenv() Base = declarative_base() @@ -44,7 +43,6 @@ def drop_table(self, table_class): if inspector.has_table(table_class.__tablename__): table_class.__table__.drop(self.engine) -# データ生成 fake = Faker() seed_value = 42 random.seed(seed_value) @@ -85,7 +83,6 @@ def generate_sales_data(session, num_records): session = db.get_session() generate_sales_data(session, 20000) - print("20000件のデータがsales_dataテーブルに作成されました。") train_vanna_for_sales_data(""" CREATE TABLE sales_data ( @@ -96,4 +93,3 @@ def generate_sales_data(session, num_records): region VARCHAR(255) ) """) - print("vannaをsales_dataに合わせて訓練しました。") diff --git a/examples/use-rdb-resource/myvanna.py b/examples/use-rdb-resource/myvanna.py index 9f9c90d39..0e544d350 100644 --- a/examples/use-rdb-resource/myvanna.py +++ b/examples/use-rdb-resource/myvanna.py @@ -3,10 +3,8 @@ from vanna.openai import OpenAI_Chat from vanna.chromadb import ChromaDB_VectorStore -# .envファイルの読み込み load_dotenv() -# 環境変数から接続情報を取得 db_user = os.getenv('DB_USERNAME') db_password = os.getenv('DB_PASSWORD') db_host = os.getenv('DB_HOST') @@ -14,18 +12,17 @@ db_database = os.getenv('DB_NAME') openai_api_key = os.getenv('OPENAI_API_KEY') -# MyVannaクラス定義 class MyVanna(ChromaDB_VectorStore, OpenAI_Chat): def __init__(self, config=None): ChromaDB_VectorStore.__init__(self, config=config) OpenAI_Chat.__init__(self, config=config) -# sales_dataに基づいてVannaを訓練する関数 + def train_vanna_for_sales_data(ddl): vn_openai = MyVanna(config={'model': 'gpt-4o', 'api_key': openai_api_key}) vn_openai.train(ddl=ddl) -# プロンプトからSQLを生成する関数 + def generate_sql_from_prompt(question) -> str: vn_openai = MyVanna(config={'model': 'gpt-4o', 'api_key': openai_api_key}) vn_openai.connect_to_mysql(host=db_host, dbname=db_database, user=db_user, password=db_password, port=db_port)