Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions extra/rag_agent.txt
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,19 @@ def search_vector_db(query):
return snippets

# --- See if we can pull a city from the RAG results ---
def extract_city_from_rag(snippets):
"""Try to extract city names from RAG results"""
def extract_city_from_rag(snippets, user_input):
"""Try to extract city names from RAG results. Uses conversation history like an LLM for match validation"""
KNOWN_CITIES = ["New York", "San Francisco", "Chicago", "Austin", "Boston",
"London", "Toronto", "Tokyo", "Sydney", "Berlin"]

known_cities_lower = {city.lower(): city for city in KNOWN_CITIES}
user_input_lower = user_input.lower()

for snippet in snippets:
for city in KNOWN_CITIES:
if city.lower() in snippet.lower():
print(f"{GREEN}RAG detected city: {city}{RESET}")
return city
if user_input_lower in known_cities_lower and user_input_lower in snippet.lower():
city = known_cities_lower[user_input_lower]
print(f"{GREEN}RAG detected city: {city}{RESET}")
return city
return None

# --- As a fallback, try to pull a city via the LLM ---
Expand Down Expand Up @@ -236,7 +239,7 @@ if __name__ == "__main__":
rag_snippets = search_vector_db(user_input)

# City Detection Workflow
detected_city = extract_city_from_rag(rag_snippets)
detected_city = extract_city_from_rag(rag_snippets, user_input)
if not detected_city:
detected_city = fallback_detect_city_with_llm(user_input)

Expand Down