Skip to content

Commit c589e4a

Browse files
committed
feat: add memory to prompt
1 parent 24ae670 commit c589e4a

File tree

5 files changed

+42
-30
lines changed

5 files changed

+42
-30
lines changed

README-CH.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ DB_URL='postgresql://user:password@host:5432/database'
4343
```bash
4444
pip install streamlit==1.42.2 pandas==2.2.3 python-dotenv==1.0.1 \
4545
langchain-community==0.3.19 langchain-openai==0.3.7 langchain==0.3.20 \
46-
langchain-ollama==0.2.3
46+
langchain-ollama==0.2.3 python-magic-bin==0.4.14
4747
```
4848

4949

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Modify the prompt in the code based on your database information. Some templates
4242
```bash
4343
pip install streamlit==1.42.2 pandas==2.2.3 python-dotenv==1.0.1 \
4444
langchain-community==0.3.19 langchain-openai==0.3.7 langchain==0.3.20 \
45-
langchain-ollama==0.2.3
45+
langchain-ollama==0.2.3 python-magic-bin==0.4.14
4646
```
4747

4848

app.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from langchain_community.utilities import SQLDatabase
77

88
from llm_util import get_llm
9-
from prompt_util import get_prompt
109
from rag_util import get_vector_store, run_rag
1110
from sql_util import clean_sql_response, convert_result_to_df
1211

@@ -51,11 +50,12 @@
5150

5251
sql_query = None
5352
query_result = None
53+
memory = []
5454

5555
for retry in range(1, MAX_RETRIES + 1):
5656
try:
5757
# 生成 SQL 查詢
58-
sql_query = run_rag(llm, vector_store, user_input, table_info)
58+
sql_query = run_rag(llm, vector_store, user_input, table_info, memory)
5959

6060
# 清理 SQL 查詢字串
6161
sql_query = clean_sql_response(sql_query)
@@ -67,18 +67,28 @@
6767
break
6868

6969
except Exception as e:
70+
error_message = str(e)
71+
memory.append({
72+
"sql": sql_query,
73+
"error": error_message,
74+
})
75+
7076
if retry < MAX_RETRIES:
7177
with st.chat_message("assistant"):
7278
st.markdown(
73-
f"⚠️ SQL 執行失敗:`{sql_query}`,正在重新嘗試 ({retry}/{MAX_RETRIES})...")
79+
f"❌ SQL 執行失敗:\n```sql\n{sql_query}\n```\n"
80+
f"❗ 錯誤訊息:`{error_message}`,正在重新嘗試 ({retry}/{MAX_RETRIES})...")
7481
else:
7582
with st.chat_message("assistant"):
76-
st.markdown(f"❌ SQL 執行失敗:{e}")
83+
st.markdown(
84+
f"❌ SQL 執行失敗:\n```sql\n{sql_query}\n```\n"
85+
f"❗ 錯誤訊息:`{error_message}`")
7786

7887
# 移除錯誤的 SQL 語法
7988
sql_query = None
8089
break
8190

91+
# 順利產生 sql 後嘗試執行
8292
if sql_query:
8393

8494
with st.chat_message("assistant"):

prompt_util.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44

5-
def get_prompt(example: str | None):
5+
def get_prompt(example: str | None, memory: str | None):
66
"""
77
自訂產生 SQL 的 Prompt
88
"""
@@ -19,31 +19,24 @@ def get_prompt(example: str | None):
1919
### 1. SQL 語法規則
2020
2121
- **表格名稱必須加上雙引號**
22-
- ✅ 正確:`SELECT 名稱 FROM "table_name"`
23-
- ❌ 錯誤:`SELECT 名稱 FROM table_name`(**表格名稱沒有雙引號,錯誤!**)
24-
- ❌ 錯誤:`SELECT 名稱 FROM 'table_name'`(**單引號錯誤!**)
25-
26-
- **欄位名稱不可使用雙引號**
27-
- ✅ 正確:`SELECT 名稱 FROM "table_name"`
28-
- ❌ 錯誤:`SELECT "名稱" FROM "table_name"`(**欄位名稱不應加雙引號!**)
22+
- **欄位名稱不可使用任何符號**
23+
- 範例:`SELECT 名稱 FROM "table_name"`
2924
3025
- **聚合函數必須加括號**
31-
- ✅ 正確:`SELECT MIN(日期) FROM "table_name"`
32-
- ❌ 錯誤:`SELECT MIN 日期 FROM "table_name"`
26+
- 範例:`SELECT MIN(日期) FROM "table_name"`
3327
3428
### 2. `GROUP BY` 使用規則
3529
3630
- **當問題涉及「最高」「最低」「平均」「總和」時,一定要 `GROUP BY`**
3731
- **當問題要求「每個對象」、「每個類別」、「每個項目」時,一定要 `GROUP BY`**
38-
- ❌ 錯誤:`SELECT AVG(數值) FROM "table_name"`
39-
- ✅ 正確:`SELECT 類別, AVG(數值) FROM "table_name" GROUP BY 類別`
32+
- 範例:`SELECT 類別, AVG(數值) FROM "table_name" GROUP BY 類別`
33+
4034
- **選擇正確的 `GROUP BY` 屬性**
4135
- 若問題涉及某項目(如:「哪個項目的值最高?」)➡ `GROUP BY 項目名稱`
4236
- 若問題涉及某類別(如:「哪個類別的平均值最高?」)➡ `GROUP BY 類別名稱`
4337
- 若問題未明確說明,預設使用 `GROUP BY 項目名稱`
4438
- **`SELECT` 中必須包含 `GROUP BY` 的欄位**
45-
- ❌ 錯誤:`SELECT MAX(數值) FROM "table_name" GROUP BY 類別`
46-
- ✅ 正確:`SELECT 類別, MAX(數值) FROM "table_name" GROUP BY 類別`
39+
- 範例:`SELECT 類別, MAX(數值) FROM "table_name" GROUP BY 類別`
4740
4841
"""
4942

@@ -54,4 +47,11 @@ def get_prompt(example: str | None):
5447
{example}
5548
"""
5649

50+
if memory is not None and memory != "":
51+
prompt_template += f"""
52+
53+
以下是目前為止的嘗試及錯誤訊息:
54+
{memory}
55+
"""
56+
5757
return PromptTemplate.from_template(prompt_template)

rag_util.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
1-
import ast
21
import json
32
import os
4-
import re
53

64
from langchain_ollama import OllamaEmbeddings
7-
import pandas as pd
8-
import streamlit as st
95
from dotenv import load_dotenv
10-
import matplotlib
11-
import matplotlib.pyplot as plt
126
from langchain_community.document_loaders import DirectoryLoader
13-
from langchain.prompts import PromptTemplate
147
from langchain_core.documents import Document
158
from langchain_core.vectorstores import InMemoryVectorStore
169

@@ -52,7 +45,7 @@ def get_vector_store():
5245
return vector_store
5346

5447

55-
def run_rag(llm, vector_store, user_input, table_info):
48+
def run_rag(llm, vector_store, user_input, table_info, memory):
5649
"""
5750
執行 RAG 流程,先檢索相似內容,再生成 SQL 查詢
5851
"""
@@ -70,13 +63,22 @@ def run_rag(llm, vector_store, user_input, table_info):
7063
回答: `{best_match.metadata["response"]}`
7164
""" if best_match else None
7265

73-
prompt = get_prompt(example)
66+
# 整理記憶
67+
memory_str = ""
68+
if memory:
69+
for item in memory:
70+
memory_str += f"""
71+
SQL: {item['sql']}
72+
錯誤訊息:`{item['error']}`
73+
"""
74+
75+
76+
prompt = get_prompt(example, memory_str)
7477

7578
# 產生 SQL 查詢
7679
messages = prompt.invoke({
7780
"input": user_input,
7881
"table_info": table_info,
79-
"example": example,
8082
"top_k": 20
8183
})
8284

0 commit comments

Comments
 (0)