|
1 |
| -from config import DB_URL |
| 1 | +import re |
2 | 2 |
|
| 3 | +import matplotlib |
| 4 | +import matplotlib.pyplot as plt |
| 5 | +import pandas as pd |
3 | 6 | import streamlit as st
|
4 | 7 | from langchain.chains import create_sql_query_chain
|
| 8 | +from langchain.schema import HumanMessage |
5 | 9 | from langchain_community.utilities import SQLDatabase
|
6 | 10 |
|
| 11 | +from config import DB_URL |
7 | 12 | from utils.llm_util import get_llm
|
8 | 13 | from utils.rag_util import get_vector_store, run_rag
|
9 | 14 | from utils.sql_util import clean_sql_response, convert_result_to_df
|
10 | 15 |
|
11 |
| - |
12 | 16 | MAX_RETRIES = 3 # 最多重試次數
|
13 | 17 |
|
14 | 18 | # 連接資料庫
|
|
32 | 36 |
|
33 | 37 | # 使用者輸入
|
34 | 38 | user_input = st.chat_input("請輸入您的問題...")
|
| 39 | +result_df = None |
35 | 40 |
|
36 | 41 | if user_input:
|
37 | 42 | # 顯示使用者輸入
|
|
83 | 88 | sql_query = None
|
84 | 89 | break
|
85 | 90 |
|
86 |
| - # 順利產生 sql 後嘗試執行 |
87 |
| - if sql_query: |
88 |
| - # 顯示在 Streamlit 上 |
89 |
| - with st.chat_message("assistant"): |
90 |
| - st.markdown(f"**生成的 SQL 查詢:**\n```sql\n{sql_query}\n```") |
| 91 | +# 順利產生 sql 後嘗試執行 |
| 92 | +if sql_query: |
| 93 | + # 顯示在 Streamlit 上 |
| 94 | + with st.chat_message("assistant"): |
| 95 | + st.markdown(f"**生成的 SQL 查詢:**\n```sql\n{sql_query}\n```") |
91 | 96 |
|
92 |
| - try: |
93 |
| - # 將查詢結果轉換成表格 |
94 |
| - result_df = convert_result_to_df(query_result) |
| 97 | + try: |
| 98 | + # 將查詢結果轉換成表格 |
| 99 | + result_df = convert_result_to_df(query_result) |
95 | 100 |
|
96 |
| - # 顯示查詢結果 |
97 |
| - if not result_df.empty: |
98 |
| - with st.chat_message("assistant"): |
99 |
| - st.markdown(f"✅ 查詢成功,結果如下:") |
100 |
| - with st.chat_message("table"): |
101 |
| - st.dataframe(result_df) |
102 |
| - else: |
103 |
| - # 沒有資料 |
104 |
| - with st.chat_message("table"): |
105 |
| - st.markdown(f"⚠️ 沒有查詢結果。") |
106 |
| - |
107 |
| - # 轉換錯誤 |
108 |
| - except Exception as e: |
| 101 | + # 顯示查詢結果 |
| 102 | + if not result_df.empty: |
| 103 | + with st.chat_message("assistant"): |
| 104 | + st.markdown(f"✅ 查詢成功,結果如下:") |
| 105 | + with st.chat_message("table"): |
| 106 | + st.dataframe(result_df) |
| 107 | + else: |
| 108 | + # 沒有資料 |
109 | 109 | with st.chat_message("table"):
|
110 |
| - st.markdown(f"❌ 結果處理錯誤:{e}") |
| 110 | + st.markdown(f"⚠️ 沒有查詢結果。") |
| 111 | + |
| 112 | + # 轉換錯誤 |
| 113 | + except Exception as e: |
| 114 | + with st.chat_message("table"): |
| 115 | + st.markdown(f"❌ 結果處理錯誤:{e}") |
| 116 | + |
| 117 | + |
| 118 | +# 順利產生 result_df 後嘗試執行 |
| 119 | +if user_input and result_df is not None and not result_df.empty: |
| 120 | + |
| 121 | + # 讓 LLM 生成 Matplotlib 程式碼 |
| 122 | + plot_prompt = f""" |
| 123 | + 請根據 使用者問題 '{user_input}' 以及 Pandas DataFrame `result_df` 繪製 Matplotlib 圖表: |
| 124 | + {result_df.head().to_string()} |
| 125 | + |
| 126 | + 並確保符合以下要求: |
| 127 | +
|
| 128 | + 1. **請嚴格遵守:** |
| 129 | + - **只能使用 `pd` 和 `matplotlib` 和 `plt` 套件,且不需要寫任何 import ,不允許操作其他套件。 |
| 130 | + - **禁止** 重新定義 `result_df`,不要包含 `result_df = ...` 或 `pd.DataFrame({...})`。 |
| 131 | + - **請直接使用** `result_df`,假設它已經存在且包含完整數據,不需要創建新數據。 |
| 132 | + - **不得修改 `result_df` 的數據**,只能轉換格式(如 `pd.to_datetime()`)。 |
| 133 | +
|
| 134 | + 2. **程式碼格式:** |
| 135 | + - **請使用 `fig, ax = plt.subplots()`** 來建立圖表。 |
| 136 | + - **請使用 `ax.plot()`,而不是 `plt.plot()`。** |
| 137 | + - **請確保有產生 `fig` 變數,但不需要返回 `fig` 變數,不要使用 `plt.show()`**。 |
| 138 | +
|
| 139 | + 3. **圖表設定:** |
| 140 | + - 設定適當的 `figsize`,長度設定為 8。 |
| 141 | + - 使用 `ax.grid(True)` 啟用網格線。 |
| 142 | + - 設定合適的標題與標籤,請確保 `plt` 參數有設定中文字型,才能顯示中文,使用 `plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']` 。 |
| 143 | + - **只有在資料包含時間序列時才可使用折線圖**。例如,顯示時間變化的數據,如每日銷售、溫度隨時間變化等,才能使用折線圖。 |
| 144 | + - **如果資料中沒有時間變數,請避免使用折線圖**,而應選擇其他適合的圖表類型,例如條形圖、圓餅圖等。 |
| 145 | + - 請確保根據資料的性質選擇正確的圖表。折線圖僅用於顯示時間序列數據,其他情境請選擇其他圖表形式。 |
| 146 | + - **如果 X 軸的標籤過多,請調整顯示方式**,例如將 X 軸標籤旋轉為斜體(例如 45 度角),以避免文字擠在一起。若標籤仍然過於擠密,考慮將標籤改為垂直顯示,或僅顯示部分標籤(例如每隔一個顯示一個標籤),以提高可讀性。 |
| 147 | +
|
| 148 | + **請只有 Python 程式碼,**不要包含任何說明或解釋!** |
| 149 | + **請確保程式碼沒有 `result_df = ...`,否則無效!** |
| 150 | + **請確保程式碼沒有 return 任何東西!** |
| 151 | + """ |
| 152 | + |
| 153 | + # 取得 LLM 生成的 Python 程式碼 |
| 154 | + generated_code = llm([HumanMessage(content=plot_prompt)]).content |
| 155 | + |
| 156 | + # 使用正則表達式移除 Markdown 程式碼 |
| 157 | + clean_code = re.sub(r"```(python)?\n", "", generated_code).strip() |
| 158 | + clean_code = re.sub(r"\n```", "", clean_code) |
| 159 | + |
| 160 | + plot_module = {"matplotlib": matplotlib, 'plt': plt, "pd": pd, "result_df": result_df} |
| 161 | + # 執行生成的程式碼 |
| 162 | + exec(clean_code, plot_module) |
| 163 | + |
| 164 | + # 從 `plot_module` 取出 `fig` |
| 165 | + fig = plot_module.get("fig") |
| 166 | + |
| 167 | + # 顯示圖表到 Streamlit |
| 168 | + st.pyplot(fig=fig, use_container_width=False) |
0 commit comments