Skip to content

Commit d99f0cf

Browse files
committed
feat: add plot
1 parent 8397365 commit d99f0cf

File tree

7 files changed

+98
-35
lines changed

7 files changed

+98
-35
lines changed

Diff for: README-CH.md

+1-3
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ DB_URL='postgresql://user:password@host:5432/database'
4141
### Install Module
4242

4343
```bash
44-
pip install streamlit==1.42.2 pandas==2.2.3 python-dotenv==1.0.1 \
45-
langchain-community==0.3.19 langchain-openai==0.3.7 langchain==0.3.20 \
46-
langchain-ollama==0.2.3 python-magic-bin==0.4.14
44+
pip install -r requirements.txt
4745
```
4846

4947

Diff for: README.md

+1-3
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ Modify the prompt in the code based on your database information. Some templates
4040
### Install Module
4141

4242
```bash
43-
pip install streamlit==1.42.2 pandas==2.2.3 python-dotenv==1.0.1 \
44-
langchain-community==0.3.19 langchain-openai==0.3.7 langchain==0.3.20 \
45-
langchain-ollama==0.2.3 python-magic-bin==0.4.14
43+
pip install -r requirements.txt
4644
```
4745

4846

Diff for: app.py

+82-24
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1-
from config import DB_URL
1+
import re
22

3+
import matplotlib
4+
import matplotlib.pyplot as plt
5+
import pandas as pd
36
import streamlit as st
47
from langchain.chains import create_sql_query_chain
8+
from langchain.schema import HumanMessage
59
from langchain_community.utilities import SQLDatabase
610

11+
from config import DB_URL
712
from utils.llm_util import get_llm
813
from utils.rag_util import get_vector_store, run_rag
914
from utils.sql_util import clean_sql_response, convert_result_to_df
1015

11-
1216
MAX_RETRIES = 3 # 最多重試次數
1317

1418
# 連接資料庫
@@ -32,6 +36,7 @@
3236

3337
# 使用者輸入
3438
user_input = st.chat_input("請輸入您的問題...")
39+
result_df = None
3540

3641
if user_input:
3742
# 顯示使用者輸入
@@ -83,28 +88,81 @@
8388
sql_query = None
8489
break
8590

86-
# 順利產生 sql 後嘗試執行
87-
if sql_query:
88-
# 顯示在 Streamlit 上
89-
with st.chat_message("assistant"):
90-
st.markdown(f"**生成的 SQL 查詢:**\n```sql\n{sql_query}\n```")
91+
# 順利產生 sql 後嘗試執行
92+
if sql_query:
93+
# 顯示在 Streamlit 上
94+
with st.chat_message("assistant"):
95+
st.markdown(f"**生成的 SQL 查詢:**\n```sql\n{sql_query}\n```")
9196

92-
try:
93-
# 將查詢結果轉換成表格
94-
result_df = convert_result_to_df(query_result)
97+
try:
98+
# 將查詢結果轉換成表格
99+
result_df = convert_result_to_df(query_result)
95100

96-
# 顯示查詢結果
97-
if not result_df.empty:
98-
with st.chat_message("assistant"):
99-
st.markdown(f"✅ 查詢成功,結果如下:")
100-
with st.chat_message("table"):
101-
st.dataframe(result_df)
102-
else:
103-
# 沒有資料
104-
with st.chat_message("table"):
105-
st.markdown(f"⚠️ 沒有查詢結果。")
106-
107-
# 轉換錯誤
108-
except Exception as e:
101+
# 顯示查詢結果
102+
if not result_df.empty:
103+
with st.chat_message("assistant"):
104+
st.markdown(f"✅ 查詢成功,結果如下:")
105+
with st.chat_message("table"):
106+
st.dataframe(result_df)
107+
else:
108+
# 沒有資料
109109
with st.chat_message("table"):
110-
st.markdown(f"❌ 結果處理錯誤:{e}")
110+
st.markdown(f"⚠️ 沒有查詢結果。")
111+
112+
# 轉換錯誤
113+
except Exception as e:
114+
with st.chat_message("table"):
115+
st.markdown(f"❌ 結果處理錯誤:{e}")
116+
117+
118+
# 順利產生 result_df 後嘗試執行
119+
if user_input and result_df is not None and not result_df.empty:
120+
121+
# 讓 LLM 生成 Matplotlib 程式碼
122+
plot_prompt = f"""
123+
請根據 使用者問題 '{user_input}' 以及 Pandas DataFrame `result_df` 繪製 Matplotlib 圖表:
124+
{result_df.head().to_string()}
125+
126+
並確保符合以下要求:
127+
128+
1. **請嚴格遵守:**
129+
- **只能使用 `pd` 和 `matplotlib` 和 `plt` 套件,且不需要寫任何 import ,不允許操作其他套件。
130+
- **禁止** 重新定義 `result_df`,不要包含 `result_df = ...` 或 `pd.DataFrame({...})`。
131+
- **請直接使用** `result_df`,假設它已經存在且包含完整數據,不需要創建新數據。
132+
- **不得修改 `result_df` 的數據**,只能轉換格式(如 `pd.to_datetime()`)。
133+
134+
2. **程式碼格式:**
135+
- **請使用 `fig, ax = plt.subplots()`** 來建立圖表。
136+
- **請使用 `ax.plot()`,而不是 `plt.plot()`。**
137+
- **請確保有產生 `fig` 變數,但不需要返回 `fig` 變數,不要使用 `plt.show()`**。
138+
139+
3. **圖表設定:**
140+
- 設定適當的 `figsize`,長度設定為 8。
141+
- 使用 `ax.grid(True)` 啟用網格線。
142+
- 設定合適的標題與標籤,請確保 `plt` 參數有設定中文字型,才能顯示中文,使用 `plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']` 。
143+
- **只有在資料包含時間序列時才可使用折線圖**。例如,顯示時間變化的數據,如每日銷售、溫度隨時間變化等,才能使用折線圖。
144+
- **如果資料中沒有時間變數,請避免使用折線圖**,而應選擇其他適合的圖表類型,例如條形圖、圓餅圖等。
145+
- 請確保根據資料的性質選擇正確的圖表。折線圖僅用於顯示時間序列數據,其他情境請選擇其他圖表形式。
146+
- **如果 X 軸的標籤過多,請調整顯示方式**,例如將 X 軸標籤旋轉為斜體(例如 45 度角),以避免文字擠在一起。若標籤仍然過於擠密,考慮將標籤改為垂直顯示,或僅顯示部分標籤(例如每隔一個顯示一個標籤),以提高可讀性。
147+
148+
**請只有 Python 程式碼,**不要包含任何說明或解釋!**
149+
**請確保程式碼沒有 `result_df = ...`,否則無效!**
150+
**請確保程式碼沒有 return 任何東西!**
151+
"""
152+
153+
# 取得 LLM 生成的 Python 程式碼
154+
generated_code = llm([HumanMessage(content=plot_prompt)]).content
155+
156+
# 使用正則表達式移除 Markdown 程式碼
157+
clean_code = re.sub(r"```(python)?\n", "", generated_code).strip()
158+
clean_code = re.sub(r"\n```", "", clean_code)
159+
160+
plot_module = {"matplotlib": matplotlib, 'plt': plt, "pd": pd, "result_df": result_df}
161+
# 執行生成的程式碼
162+
exec(clean_code, plot_module)
163+
164+
# 從 `plot_module` 取出 `fig`
165+
fig = plot_module.get("fig")
166+
167+
# 顯示圖表到 Streamlit
168+
st.pyplot(fig=fig, use_container_width=False)

Diff for: requirements.txt

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
streamlit==1.42.2
2+
pandas==2.2.3
3+
python-dotenv==1.0.1
4+
langchain-community==0.3.19
5+
langchain-openai==0.3.7
6+
langchain==0.3.20
7+
langchain-ollama==0.2.3
8+
python-magic-bin==0.4.14
9+
matplotlib

Diff for: script.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from langchain_community.utilities import SQLDatabase
66

77
from utils.llm_util import get_llm
8-
from utils.prompt_util import get_prompt
8+
from utils.prompt_util import get_sql_prompt
99
from utils.sql_util import clean_sql_response, convert_result_to_df
1010

1111
# 讀取 .env 變數
@@ -32,7 +32,7 @@
3232
llm = get_llm()
3333

3434
# 自訂 Prompt
35-
prompt = get_prompt()
35+
prompt = get_sql_prompt()
3636

3737
# 創建 SQL 查詢鏈
3838
chain = create_sql_query_chain(llm, db, prompt=prompt)

Diff for: utils/prompt_util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44

5-
def get_prompt(example: str | None, memory: str | None):
5+
def get_sql_prompt(example: str | None, memory: str | None):
66
"""
77
自訂產生 SQL 的 Prompt
88
"""

Diff for: utils/rag_util.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from langchain_core.documents import Document
77
from langchain_core.vectorstores import InMemoryVectorStore
88

9-
from utils.prompt_util import get_prompt
9+
from utils.prompt_util import get_sql_prompt
1010

1111

1212
def get_vector_store():
@@ -61,7 +61,7 @@ def run_rag(llm, vector_store, user_input, table_info, memory):
6161
"""
6262

6363

64-
prompt = get_prompt(example, memory_str)
64+
prompt = get_sql_prompt(example, memory_str)
6565

6666
# 產生 SQL 查詢
6767
messages = prompt.invoke({

0 commit comments

Comments
 (0)