Skip to content

Commit 8397365

Browse files
committed
feat: add utils dir
1 parent c589e4a commit 8397365

File tree

8 files changed

+76
-72
lines changed

8 files changed

+76
-72
lines changed

app.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,13 @@
1-
import os
1+
from config import DB_URL
22

33
import streamlit as st
4-
from dotenv import load_dotenv
54
from langchain.chains import create_sql_query_chain
65
from langchain_community.utilities import SQLDatabase
76

8-
from llm_util import get_llm
9-
from rag_util import get_vector_store, run_rag
10-
from sql_util import clean_sql_response, convert_result_to_df
7+
from utils.llm_util import get_llm
8+
from utils.rag_util import get_vector_store, run_rag
9+
from utils.sql_util import clean_sql_response, convert_result_to_df
1110

12-
# 讀取 .env 變數
13-
load_dotenv()
14-
15-
16-
# 取得資料庫連線字串
17-
DB_URL = os.getenv("DB_URL")
18-
if not DB_URL:
19-
raise Exception("未在 .env 檔案中找到 DB_URL")
2011

2112
MAX_RETRIES = 3 # 最多重試次數
2213

@@ -43,7 +34,6 @@
4334
user_input = st.chat_input("請輸入您的問題...")
4435

4536
if user_input:
46-
4737
# 顯示使用者輸入
4838
with st.chat_message("user"):
4939
st.markdown(user_input)
@@ -52,6 +42,7 @@
5242
query_result = None
5343
memory = []
5444

45+
# 嘗試 MAX_RETRIES 次
5546
for retry in range(1, MAX_RETRIES + 1):
5647
try:
5748
# 生成 SQL 查詢
@@ -67,12 +58,16 @@
6758
break
6859

6960
except Exception as e:
61+
# 將錯誤轉為字串
7062
error_message = str(e)
63+
64+
# 記錄錯誤到 memory 中已供下次使用
7165
memory.append({
7266
"sql": sql_query,
7367
"error": error_message,
7468
})
7569

70+
# 顯示在 Streamlit 上
7671
if retry < MAX_RETRIES:
7772
with st.chat_message("assistant"):
7873
st.markdown(
@@ -90,7 +85,7 @@
9085

9186
# 順利產生 sql 後嘗試執行
9287
if sql_query:
93-
88+
# 顯示在 Streamlit 上
9489
with st.chat_message("assistant"):
9590
st.markdown(f"**生成的 SQL 查詢:**\n```sql\n{sql_query}\n```")
9691

@@ -105,9 +100,11 @@
105100
with st.chat_message("table"):
106101
st.dataframe(result_df)
107102
else:
103+
# 沒有資料
108104
with st.chat_message("table"):
109105
st.markdown(f"⚠️ 沒有查詢結果。")
110106

107+
# 轉換錯誤
111108
except Exception as e:
112109
with st.chat_message("table"):
113110
st.markdown(f"❌ 結果處理錯誤:{e}")

config.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
from dotenv import load_dotenv
3+
4+
5+
6+
# 讀取 .env 變數
7+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
8+
load_dotenv(dotenv_path=os.path.join(BASE_DIR, ".env"))
9+
10+
# 資料庫連線
11+
DB_URL = os.getenv("DB_URL")
12+
if not DB_URL:
13+
raise Exception("未在 .env 檔案中找到 DB_URL。")
14+
15+
# LLM 模型
16+
LLM_TYPE = os.getenv("LLM_TYPE", "OPENAI") # 默認為 OPENAI
17+
18+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
19+
OPENAI_MODEL = os.getenv("OPENAI_MODEL")
20+
21+
OLLAMA_URL = os.getenv("OLLAMA_URL")
22+
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL")
23+
24+
if LLM_TYPE == "OPENAI":
25+
if not OPENAI_API_KEY:
26+
raise Exception("未在 .env 檔案中找到 OPENAI_API_KEY。")
27+
if not OPENAI_MODEL:
28+
raise Exception("未在 .env 檔案中找到 OPENAI_MODEL")
29+
elif LLM_TYPE == "OLLAMA":
30+
if not OLLAMA_URL:
31+
raise Exception("未在 .env 檔案中找到 OLLAMA_URL")
32+
if not OLLAMA_MODEL:
33+
raise Exception("未在 .env 檔案中找到 OLLAMA_MODEL")
34+
35+
# Embedding
36+
OLLAMA_EMBEDDING_URL = os.getenv("OLLAMA_EMBEDDING_URL")
37+
if not OLLAMA_EMBEDDING_URL:
38+
raise Exception("未在 .env 檔案中找到 OLLAMA_EMBEDDING_URL。")
39+
OLLAMA_EMBEDDING_MODEL = os.getenv("OLLAMA_EMBEDDING_MODEL")
40+
if not OLLAMA_EMBEDDING_MODEL:
41+
raise Exception("未在 .env 檔案中找到 OLLAMA_EMBEDDING_MODEL。")

llm_util.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

script.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from langchain.chains import create_sql_query_chain
55
from langchain_community.utilities import SQLDatabase
66

7-
from llm_util import get_llm
8-
from prompt_util import get_prompt
9-
from sql_util import clean_sql_response, convert_result_to_df
7+
from utils.llm_util import get_llm
8+
from utils.prompt_util import get_prompt
9+
from utils.sql_util import clean_sql_response, convert_result_to_df
1010

1111
# 讀取 .env 變數
1212
load_dotenv()

utils/llm_util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from config import LLM_TYPE, OLLAMA_MODEL, OLLAMA_URL, OPENAI_API_KEY, OPENAI_MODEL
2+
3+
from langchain_ollama.chat_models import ChatOllama
4+
from langchain_openai import ChatOpenAI
5+
6+
7+
8+
def get_llm():
9+
# 初始化 LLM 模型
10+
if LLM_TYPE == "OPENAI":
11+
return ChatOpenAI(model=OPENAI_MODEL, temperature=0, api_key=OPENAI_API_KEY)
12+
elif LLM_TYPE == "OLLAMA":
13+
return ChatOllama(model=OLLAMA_MODEL, base_url=OLLAMA_URL)
14+
15+
else:
16+
raise Exception(f"未支援的 LLM_TYPE: {LLM_TYPE}")
17+
File renamed without changes.

rag_util.py renamed to utils/rag_util.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,15 @@
1-
import json
2-
import os
1+
from config import OLLAMA_EMBEDDING_MODEL, OLLAMA_EMBEDDING_URL
32

3+
import json
44
from langchain_ollama import OllamaEmbeddings
5-
from dotenv import load_dotenv
65
from langchain_community.document_loaders import DirectoryLoader
76
from langchain_core.documents import Document
87
from langchain_core.vectorstores import InMemoryVectorStore
98

10-
from prompt_util import get_prompt
9+
from utils.prompt_util import get_prompt
1110

1211

13-
# 讀取 .env 變數
14-
load_dotenv()
15-
1612
def get_vector_store():
17-
# embedding
18-
OLLAMA_EMBEDDING_URL = os.getenv("OLLAMA_EMBEDDING_URL")
19-
if not OLLAMA_EMBEDDING_URL:
20-
raise Exception("未在 .env 檔案中找到 OLLAMA_EMBEDDING_URL。")
21-
OLLAMA_EMBEDDING_MODEL = os.getenv("OLLAMA_EMBEDDING_MODEL")
22-
if not OLLAMA_EMBEDDING_MODEL:
23-
raise Exception("未在 .env 檔案中找到 OLLAMA_EMBEDDING_MODEL。")
24-
2513
embeddings = OllamaEmbeddings(model=OLLAMA_EMBEDDING_MODEL, base_url=OLLAMA_EMBEDDING_URL)
2614

2715
vector_store = InMemoryVectorStore(embeddings)
File renamed without changes.

0 commit comments

Comments
 (0)