diff --git a/README.md b/README.md index e431066a..3a80a32c 100644 --- a/README.md +++ b/README.md @@ -224,6 +224,18 @@ This repository contains sample code demonstrating various use cases leveraging ![Screen Recording of Amazon Bedrock Video Chapter Creator POC](genai-quickstart-pocs-python/amazon-bedrock-video-chapter-creator-poc/images/demo.gif) +1. **Sales Analyst Bedrock Databricks POC** + This is sample code demonstrating the use of Amazon Bedrock and Generative AI to create an intelligent sales data analyst that uses natural language questions to query relational data stores, specifically Databricks. This example leverages the complete Northwind sample database with realistic sales scenarios containing customers, orders, and order details. + + ![Screen Recording of Sales Analyst Bedrock Databricks POC](Sales-Analyst-Bedrock-Databricks/images/demo.gif) + + +1. **Sales Analyst Bedrock Snowflake POC** + This is sample code demonstrating the use of Amazon Bedrock and Generative AI to create an intelligent sales data analyst that uses natural language questions to query relational data stores, specifically Snowflake. This example leverages the complete Northwind sample database with realistic sales scenarios and includes automatic database setup. + + ![Screen Recording of Sales Analyst Bedrock Snowflake POC](Sales-Analyst-Bedrock-Snowflake/images/demo.gif) + + ## Sample Proof of Concepts - .NET diff --git a/Sales-Analyst-Bedrock-Databricks/Apache-2.0 license b/Sales-Analyst-Bedrock-Databricks/Apache-2.0 license new file mode 100644 index 00000000..dda720f3 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/Apache-2.0 license @@ -0,0 +1,204 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (which shall not include communications that are solely written + by You). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based upon (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and derivative works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control + systems, and issue tracking systems that are managed by, or on behalf + of, the Licensor for the purpose of discussing and improving the Work, + but excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to use, reproduce, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Work, and to + permit persons to whom the Work is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Work. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, trademark, patent, + attribution and other notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright notice to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Support. When redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional support. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in comments for the + particular file format. We also recommend that a file or class name + and description of purpose be included on the same page as the + copyright notice for easier identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/README.md b/Sales-Analyst-Bedrock-Databricks/README.md new file mode 100644 index 00000000..a8887edd --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/README.md @@ -0,0 +1,277 @@ +# Amazon Bedrock & Databricks Sales Analyst POC (Text to SQL) + +## Overview of Solution + +This is sample code demonstrating the use of Amazon Bedrock and Generative AI to create an intelligent sales data analyst that uses natural language questions to query relational data stores, specifically Databricks. This example leverages the complete Northwind sample database with realistic sales scenarios containing customers, orders, and order details. + +![Sales Analyst Demo](images/demo.gif) + +## Goal of this POC +The goal of this repo is to provide users the ability to use Amazon Bedrock and generative AI to ask natural language questions about sales performance, customer behavior, and business metrics. 
These questions are automatically transformed into optimized SQL queries against a Databricks workspace. This repo includes intelligent context retrieval using FAISS vector store, LangGraph workflow orchestration, and complete Databricks automation. + +The architecture & flow of the POC is as follows: +![POC Architecture & Flow](images/architecture.png 'POC Architecture') + +When a user interacts with the POC, the flow is as follows: + +1. **Natural Language Query**: The user makes a request through the Streamlit interface, asking a natural language question about sales data in Databricks (`app.py`) + +2. **Query Understanding**: The natural language question is passed to Amazon Bedrock for intent analysis and query classification (`src/graph/workflow.py`) + +3. **Context Retrieval**: The system performs semantic search using FAISS vector store to retrieve relevant database schema information and table relationships (`src/vector_store/faiss_manager.py`) + +4. **Intelligent SQL Generation**: Amazon Bedrock generates optimized SQL queries using the retrieved context, ensuring proper table joins and data type handling (`src/graph/workflow.py`) + +5. **Secure Query Execution**: The SQL query is executed against the Databricks workspace through secure REST API connection (`src/utils/databricks_rest_connector.py`) + +6. **Result Analysis**: The retrieved data is passed back to Amazon Bedrock for intelligent analysis and insight generation (`src/graph/workflow.py`) + +7. **Natural Language Response**: The system returns comprehensive insights and explanations to the user through the Streamlit frontend (`app.py`) + +# How to use this Repo: + +## Prerequisites: + +1. [AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html) installed and configured with access to Amazon Bedrock. + +2. [Python](https://www.python.org/downloads/) v3.8 or greater. The POC runs on Python. + +3. Databricks account with appropriate permissions to create workspaces and warehouses. + +4. AWS account with permissions to access Amazon Bedrock services. + +## Steps + +1. Install Git (Optional step): + ```bash + # Amazon Linux / CentOS / RHEL: + sudo yum install -y git + # Ubuntu / Debian: + sudo apt-get install -y git + # Mac/Windows: Git is usually pre-installed + ``` + +2. Clone the repository to your local machine. + + ```bash + git clone https://github.com/AWS-Samples-GenAI-FSI/Sales-Analyst-Bedrock-Databricks.git + + ``` + + The file structure of this POC is organized as follows: + + * `requirements.txt` - All dependencies needed for the application + * `app.py` - Main Streamlit application with UI components + * `src/bedrock/bedrock_helper.py` - Amazon Bedrock client wrapper + * `src/graph/workflow.py` - LangGraph workflow orchestration + * `src/vector_store/faiss_manager.py` - FAISS vector store for semantic search + * `src/utils/databricks_rest_connector.py` - Databricks REST API connection management + * `src/utils/github_data_loader.py` - Automated data download from GitHub + * `src/utils/northwind_bootstrapper.py` - Automatic sample data loading + * `src/utils/databricks_workspace_manager.py` - Databricks workspace automation + * `src/monitoring/langfuse_monitor.py` - LangFuse monitoring integration + +3. Open the repository in your favorite code editor. In the terminal, navigate to the POC's folder: + ```bash + cd Sales-Analyst-Bedrock-Databricks + ``` + +4. 
Configure the Python virtual environment, activate it: + ```bash + python -m venv .venv + source .venv/bin/activate # On Windows: .venv\\Scripts\\activate + ``` + +5. Install project dependencies: + ```bash + pip install -r requirements.txt + ``` + +6. **Get your Databricks credentials:** + + **Workspace URL:** + - Copy the URL from your browser when logged into Databricks + - Format: `https://dbc-xxxxxxxx-xxxx.cloud.databricks.com` + + **Personal Access Token:** + - In Databricks, click your profile icon (top right) + - Go to "User Settings" + - Click "Developer" β†’ "Access Tokens" + - Click "Generate New Token" + - Give it a name like "Sales Analyst App" + - Set expiration (or leave blank for no expiration) + - Click "Generate" and **copy the token immediately** (you can't see it again!) + +7. Configure your credentials by editing the `.env` file: + + ```bash + # AWS Configuration (Required) + AWS_REGION=us-east-1 + AWS_ACCESS_KEY_ID=your_access_key_here + AWS_SECRET_ACCESS_KEY=your_secret_key_here + + # Databricks Configuration (Required) + DATABRICKS_HOST=https://dbc-xxxxxxxx-xxxx.cloud.databricks.com + DATABRICKS_TOKEN=your_databricks_token_here + DATABRICKS_CLUSTER_ID=auto_created # Optional: see step 8 + DATABRICKS_CATALOG=workspace + DATABRICKS_SCHEMA=northwind + ``` + +8. **Get your Databricks warehouse ID (Optional):** + - Go to "SQL Warehouses" in your Databricks workspace + - Click on "Serverless Starter Warehouse" (default warehouse) + - Copy the warehouse ID from the connection details tab + - Update `DATABRICKS_CLUSTER_ID` in `.env` with this ID + - Or leave as `auto_created` - the app will find it automatically + +9. Start the application from your terminal: + ```bash + streamlit run app.py + ``` + +10. **Automatic Setup**: On first run, the application will automatically: + - Create Databricks workspace (if needed) + - Launch serverless SQL warehouse with optimized configuration + - Download complete Northwind dataset from GitHub (91 customers, 830 orders, 2155 order details) + - Load data with proper relationships and foreign keys + - Initialize AI components and vector store + - This process takes approximately 5-8 minutes + +11. **Start Analyzing**: Once setup is complete, you can ask natural language questions like: + - "What are the top 5 customers by order value?" + - "Which customers haven't placed orders recently?" + - "Show me customer distribution by country" + - "What's the average order value by customer?" + - "Which products are most popular?" + - "Show me sales trends by month" + +## Architecture Highlights + +- **Complete Databricks Automation**: Automatically creates workspace, serverless warehouse, and schema +- **Context-Aware AI**: Semantic search for intelligent SQL generation using FAISS +- **Multi-Step AI Pipeline**: Query understanding β†’ Context retrieval β†’ SQL generation β†’ Analysis +- **Workflow Orchestration**: LangGraph-powered structured analysis workflow +- **GitHub Data Integration**: Automatically downloads complete Northwind dataset from GitHub +- **Zero Configuration**: Just add AWS credentials and run! 
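
For reference, the snippet below sketches how the warehouse lookup might behave when `DATABRICKS_CLUSTER_ID` is left as `auto_created`. It is a minimal sketch only: the real logic lives in `src/utils/databricks_workspace_manager.py` and may differ, and the helper name `resolve_warehouse_id` is illustrative. The endpoint (`GET /api/2.0/sql/warehouses`) and its `warehouses`/`id`/`name` response fields come from the public Databricks SQL Warehouses REST API.

```python
import os
import requests


def resolve_warehouse_id() -> str:
    """Illustrative helper: pick a SQL warehouse when none is pinned in .env."""
    host = os.getenv("DATABRICKS_HOST", "").rstrip("/")
    token = os.getenv("DATABRICKS_TOKEN", "")
    configured = os.getenv("DATABRICKS_CLUSTER_ID", "auto_created")
    if configured and configured != "auto_created":
        return configured  # a specific warehouse was pinned in .env

    resp = requests.get(
        f"{host}/api/2.0/sql/warehouses",
        headers={"Authorization": f"Bearer {token}"},
        timeout=30,
    )
    resp.raise_for_status()
    warehouses = resp.json().get("warehouses", [])
    if not warehouses:
        raise RuntimeError("No SQL warehouses found in this workspace")
    # Prefer the default "Serverless Starter Warehouse" when it exists.
    for wh in warehouses:
        if "starter" in wh.get("name", "").lower():
            return str(wh["id"])
    return str(warehouses[0]["id"])
```
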
+ +### Built with: + +- Amazon Bedrock: AI/ML models for natural language processing +- Databricks: Unified analytics platform with Delta Lake +- FAISS: Vector database for semantic search +- Streamlit: Web interface +- LangGraph: Workflow orchestration +- LangFuse: AI monitoring and observability + +### Database Structure +After setup, you'll have access to: +- **customers** (91 records) - Customer information and demographics +- **orders** (830 records) - Order headers with dates and shipping +- **order_details** (2,155 records) - Individual line items with quantities and prices +- **products** (77 records) - Product catalog with categories and pricing +- **categories** (8 records) - Product categories and descriptions +- **suppliers** (29 records) - Supplier information and contacts +- **employees** (9 records) - Employee data and territories +- **shippers** (3 records) - Shipping company information + +## AI-Powered Workflow +The application uses **LangGraph** and **Amazon Bedrock** to create an intelligent analysis workflow: + +1. 🧠 **Understand Query**: AI analyzes your natural language question +2. πŸ” **Retrieve Context**: Finds relevant table/column metadata using FAISS vector search +3. πŸ’» **Generate SQL**: Creates optimized SQL query using context +4. ⚑ **Execute Query**: Runs SQL against your Databricks workspace +5. πŸ“Š **Analyze Results**: Provides business insights and explanations + +### Key Features +- **Natural Language to SQL**: No SQL knowledge required +- **Intelligent Context**: Understands your database schema automatically +- **Error Recovery**: Handles and recovers from query errors +- **Performance Monitoring**: Tracks AI interactions with LangFuse +- **Delta Lake Integration**: ACID transactions and time travel capabilities + +## Monitoring (Optional) + +**LangFuse Integration** provides: +- πŸ“Š AI interaction tracking +- πŸ”„ Workflow step monitoring +- 🚨 Error logging and analysis +- ⚑ Performance metrics + +To enable, update your credentials in the connector file or set environment variables. 
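
A minimal sketch of enabling it through environment variables is shown below; the keys and host are placeholders, and the actual wiring in `src/monitoring/langfuse_monitor.py` may initialize the client differently.

```python
import os
from langfuse import Langfuse  # pinned in requirements.txt (langfuse>=3.0.1)

# Placeholder credentials -- replace with the keys from your LangFuse project.
os.environ.setdefault("LANGFUSE_PUBLIC_KEY", "pk-lf-...")
os.environ.setdefault("LANGFUSE_SECRET_KEY", "sk-lf-...")
os.environ.setdefault("LANGFUSE_HOST", "https://cloud.langfuse.com")

langfuse = Langfuse()  # the client reads the LANGFUSE_* variables above
print("LangFuse credentials valid:", langfuse.auth_check())
```
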
+ +## Troubleshooting +### Common Issues +- **"Connection failed" errors**: + - Verify your Databricks credentials are correct + - Check your workspace URL format + - Ensure your token has appropriate permissions + - Test connection: `curl -H "Authorization: Bearer $DATABRICKS_TOKEN" "$DATABRICKS_HOST/api/2.0/clusters/list"` + +- **"Setup fails" or timeouts**: + - Check your Databricks workspace is accessible + - Verify network connectivity to Databricks + - Ensure sufficient compute resources + - Check GitHub connectivity: `curl -I https://raw.githubusercontent.com/jpwhite3/northwind-SQLite3/master/csv/customers.csv` + +- **"Credentials not found"**: + - Make sure you updated the `.env` file with your actual credentials + - Make sure `.env` file is in the same directory as `app.py` + - Verify no extra spaces in your credential values + - Check that you saved the `.env` file after editing + +- **"App won't start"**: + - Ensure Python 3.8+ is installed: `python --version` + - Install requirements: `pip install -r requirements.txt` + - Try: `python -m streamlit run app.py` + +- **"AWS Bedrock access denied"**: + - Verify your AWS credentials are configured + - Check your IAM permissions for Bedrock access + - Ensure you're in a supported AWS region + +- **"SQL Generation Problems"**: + - Ensure schema information is loaded in vector store + - Check that table names match Databricks catalog structure + - Verify column names and data types + +### Getting Help +- Check Databricks query history for detailed error messages +- Review AWS CloudWatch logs for Bedrock API calls +- Enable debug logging: `logging.basicConfig(level=logging.DEBUG)` +- Ensure your Databricks workspace has no usage limits + +## Cost Management + +### Databricks Costs +- **Serverless Warehouse**: ~$0.70/hour when active +- **Auto-stop**: Configurable idle timeout (default: 10 minutes) +- **Community Edition**: Free tier available + +### AWS Costs +- **Bedrock API**: Pay-per-request pricing +- **Typical Usage**: $1-5/day for development + +### Cost Optimization Tips +- Use Community Edition for learning/testing +- Configure auto-stop for warehouses +- Monitor Bedrock API usage +- Use smaller warehouse sizes for development + +## Cleanup + +**To avoid ongoing costs, clean up demo resources when done:** + +```bash +python3 cleanup.py +``` + +**This will remove:** +- All Northwind tables and schema +- Custom warehouse (if created) +- Local cache files +- Preserves default "Serverless Starter Warehouse" + +## How-To Guide +For detailed usage instructions and advanced configuration, visit the application's help section within the Streamlit interface. \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/README_aligned.md b/Sales-Analyst-Bedrock-Databricks/README_aligned.md new file mode 100644 index 00000000..b8c9bb23 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/README_aligned.md @@ -0,0 +1,241 @@ +# Amazon Bedrock & Databricks Sales Analyst POC (Text to SQL) + +## Overview of Solution + +This is sample code demonstrating the use of Amazon Bedrock and Generative AI to create an intelligent sales data analyst that uses natural language questions to query relational data stores, specifically Databricks. This example leverages the complete Northwind sample database with realistic sales scenarios containing customers, orders, and order details. 
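
As a concrete illustration of the text-to-SQL flow, a question like the one below might be turned into a query of roughly the following shape. The pairing is illustrative only: the actual SQL is generated by Amazon Bedrock from the retrieved schema context, and table references should be qualified with whatever catalog and schema you configure.

```python
# Illustrative question/SQL pairing over the Northwind tables; the SQL the
# workflow actually produces will vary with the prompt and schema context.
question = "What are the top 5 customers by total order value?"

illustrative_sql = """
SELECT c.customerid,
       c.companyname,
       SUM(od.unitprice * od.quantity) AS total_order_value
FROM customers c
JOIN orders o ON c.customerid = o.customerid
JOIN order_details od ON o.orderid = od.orderid
GROUP BY c.customerid, c.companyname
ORDER BY total_order_value DESC
LIMIT 5
"""
```
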
+ +![Sales Analyst Demo](assets/images/demo.png) + +## Goal of this POC +The goal of this repo is to provide users the ability to use Amazon Bedrock and generative AI to ask natural language questions about sales performance, customer behavior, and business metrics. These questions are automatically transformed into optimized SQL queries against a Databricks workspace. This repo includes intelligent context retrieval using FAISS vector store, LangGraph workflow orchestration, and complete Databricks automation. + +The architecture & flow of the POC is as follows: +![POC Architecture & Flow](images/architecture.png 'POC Architecture') + +When a user interacts with the POC, the flow is as follows: + +1. **Natural Language Query**: The user makes a request through the Streamlit interface, asking a natural language question about sales data in Databricks (`app.py`) + +2. **Query Understanding**: The natural language question is passed to Amazon Bedrock for intent analysis and query classification (`src/graph/workflow.py`) + +3. **Context Retrieval**: The system performs semantic search using FAISS vector store to retrieve relevant database schema information and table relationships (`src/vector_store/faiss_manager.py`) + +4. **Intelligent SQL Generation**: Amazon Bedrock generates optimized SQL queries using the retrieved context, ensuring proper table joins and data type handling (`src/graph/workflow.py`) + +5. **Secure Query Execution**: The SQL query is executed against the Databricks workspace through secure REST API connection (`src/utils/databricks_rest_connector.py`) + +6. **Result Analysis**: The retrieved data is passed back to Amazon Bedrock for intelligent analysis and insight generation (`src/graph/workflow.py`) + +7. **Natural Language Response**: The system returns comprehensive insights and explanations to the user through the Streamlit frontend (`app.py`) + +# How to use this Repo: + +## Prerequisites: + +1. [AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html) installed and configured with access to Amazon Bedrock. + +2. [Python](https://www.python.org/downloads/) v3.8 or greater. The POC runs on Python. + +3. Databricks account with appropriate permissions to create workspaces and warehouses. + +4. AWS account with permissions to access Amazon Bedrock services. + +## Steps + +1. Install Git (Optional step): + ```bash + # Amazon Linux / CentOS / RHEL: + sudo yum install -y git + # Ubuntu / Debian: + sudo apt-get install -y git + # Mac/Windows: Git is usually pre-installed + ``` + +2. Clone the repository to your local machine. + + ```bash + git clone https://github.com/AWS-Samples-GenAI-FSI/Sales-Analyst-Bedrock-Databricks.git + + ``` + + The file structure of this POC is organized as follows: + + * `requirements.txt` - All dependencies needed for the application + * `app.py` - Main Streamlit application with UI components + * `src/bedrock/bedrock_helper.py` - Amazon Bedrock client wrapper + * `src/graph/workflow.py` - LangGraph workflow orchestration + * `src/vector_store/faiss_manager.py` - FAISS vector store for semantic search + * `src/utils/databricks_rest_connector.py` - Databricks REST API connection management + * `src/utils/github_data_loader.py` - Automated data download from GitHub + * `src/utils/northwind_bootstrapper.py` - Automatic sample data loading + * `src/utils/databricks_workspace_manager.py` - Databricks workspace automation + * `src/monitoring/langfuse_monitor.py` - LangFuse monitoring integration + +3. 
Open the repository in your favorite code editor. In the terminal, navigate to the POC's folder: + ```bash + cd Sales-Analyst-Bedrock-Databricks + ``` + +4. Configure the Python virtual environment, activate it: + ```bash + python -m venv .venv + source .venv/bin/activate # On Windows: .venv\Scripts\activate + ``` + +5. Install project dependencies: + ```bash + pip install -r requirements.txt + ``` + +6. Configure your credentials by editing the `.env` file and replacing the dummy values with your actual credentials: + + ```bash + # AWS Configuration (Required) + AWS_REGION=us-east-1 + AWS_ACCESS_KEY_ID=your_access_key_here + AWS_SECRET_ACCESS_KEY=your_secret_key_here + + # Databricks Configuration (Auto-configured) + DATABRICKS_HOST=localhost + DATABRICKS_TOKEN=auto_generated + DATABRICKS_CLUSTER_ID=auto_created + DATABRICKS_CATALOG=sales_analyst + DATABRICKS_SCHEMA=northwind + ``` + +7. Start the application from your terminal: + ```bash + streamlit run app.py + ``` + +8. **Automatic Setup**: On first run, the application will automatically: + - Create Databricks workspace (if needed) + - Launch serverless SQL warehouse with optimized configuration + - Download complete Northwind dataset from GitHub (91 customers, 830 orders, 2155 order details) + - Load data with proper relationships and foreign keys + - Initialize AI components and vector store + - This process takes approximately 5-8 minutes + +9. **Start Analyzing**: Once setup is complete, you can ask natural language questions like: + - "What are the top 5 customers by order value?" + - "Which customers haven't placed orders recently?" + - "Show me customer distribution by country" + - "What's the average order value by customer?" + - "Which products are most popular?" + - "Show me sales trends by month" + +## Architecture Highlights + +- **Complete Databricks Automation**: Automatically creates workspace, serverless warehouse, and schema +- **Context-Aware AI**: Semantic search for intelligent SQL generation using FAISS +- **Multi-Step AI Pipeline**: Query understanding β†’ Context retrieval β†’ SQL generation β†’ Analysis +- **Workflow Orchestration**: LangGraph-powered structured analysis workflow +- **GitHub Data Integration**: Automatically downloads complete Northwind dataset from GitHub +- **Zero Configuration**: Just add AWS credentials and run! + +### Built with: + +- Amazon Bedrock: AI/ML models for natural language processing +- Databricks: Unified analytics platform with Delta Lake +- FAISS: Vector database for semantic search +- Streamlit: Web interface +- LangGraph: Workflow orchestration +- LangFuse: AI monitoring and observability + +### Database Structure +After setup, you'll have access to: +- **customers** (91 records) - Customer information and demographics +- **orders** (830 records) - Order headers with dates and shipping +- **order_details** (2,155 records) - Individual line items with quantities and prices +- **products** (77 records) - Product catalog with categories and pricing +- **categories** (8 records) - Product categories and descriptions +- **suppliers** (29 records) - Supplier information and contacts +- **employees** (9 records) - Employee data and territories +- **shippers** (3 records) - Shipping company information + +## AI-Powered Workflow +The application uses **LangGraph** and **Amazon Bedrock** to create an intelligent analysis workflow: + +1. 🧠 **Understand Query**: AI analyzes your natural language question +2. 
πŸ” **Retrieve Context**: Finds relevant table/column metadata using FAISS vector search +3. πŸ’» **Generate SQL**: Creates optimized SQL query using context +4. ⚑ **Execute Query**: Runs SQL against your Databricks workspace +5. πŸ“Š **Analyze Results**: Provides business insights and explanations + +### Key Features +- **Natural Language to SQL**: No SQL knowledge required +- **Intelligent Context**: Understands your database schema automatically +- **Error Recovery**: Handles and recovers from query errors +- **Performance Monitoring**: Tracks AI interactions with LangFuse +- **Delta Lake Integration**: ACID transactions and time travel capabilities + +## Monitoring (Optional) + +**LangFuse Integration** provides: +- πŸ“Š AI interaction tracking +- πŸ”„ Workflow step monitoring +- 🚨 Error logging and analysis +- ⚑ Performance metrics + +To enable, update your credentials in the connector file or set environment variables. + +## Troubleshooting +### Common Issues +- **"Connection failed" errors**: + - Verify your Databricks credentials are correct + - Check your workspace URL format + - Ensure your token has appropriate permissions + - Test connection: `curl -H "Authorization: Bearer $DATABRICKS_TOKEN" "$DATABRICKS_HOST/api/2.0/clusters/list"` + +- **"Setup fails" or timeouts**: + - Check your Databricks workspace is accessible + - Verify network connectivity to Databricks + - Ensure sufficient compute resources + - Check GitHub connectivity: `curl -I https://raw.githubusercontent.com/jpwhite3/northwind-SQLite3/master/csv/customers.csv` + +- **"Credentials not found"**: + - Make sure you updated the `.env` file with your actual credentials + - Make sure `.env` file is in the same directory as `app.py` + - Verify no extra spaces in your credential values + - Check that you saved the `.env` file after editing + +- **"App won't start"**: + - Ensure Python 3.8+ is installed: `python --version` + - Install requirements: `pip install -r requirements.txt` + - Try: `python -m streamlit run app.py` + +- **"AWS Bedrock access denied"**: + - Verify your AWS credentials are configured + - Check your IAM permissions for Bedrock access + - Ensure you're in a supported AWS region + +- **"SQL Generation Problems"**: + - Ensure schema information is loaded in vector store + - Check that table names match Databricks catalog structure + - Verify column names and data types + +### Getting Help +- Check Databricks query history for detailed error messages +- Review AWS CloudWatch logs for Bedrock API calls +- Enable debug logging: `logging.basicConfig(level=logging.DEBUG)` +- Ensure your Databricks workspace has no usage limits + +## Cost Management + +### Databricks Costs +- **Serverless Warehouse**: ~$0.70/hour when active +- **Auto-stop**: Configurable idle timeout (default: 10 minutes) +- **Community Edition**: Free tier available + +### AWS Costs +- **Bedrock API**: Pay-per-request pricing +- **Typical Usage**: $1-5/day for development + +### Cost Optimization Tips +- Use Community Edition for learning/testing +- Configure auto-stop for warehouses +- Monitor Bedrock API usage +- Use smaller warehouse sizes for development + +## How-To Guide +For detailed usage instructions and advanced configuration, visit the application's help section within the Streamlit interface. 
\ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/app.py b/Sales-Analyst-Bedrock-Databricks/app.py new file mode 100644 index 00000000..8b0f5832 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/app.py @@ -0,0 +1,413 @@ +""" +GenAI Sales Analyst - Main application file (Databricks version). +""" +import streamlit as st +import pandas as pd +import time +import os +import numpy as np +from datetime import datetime +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Import components +from src.bedrock.bedrock_helper import BedrockHelper +from src.vector_store.faiss_manager import FAISSManager + +from src.graph.workflow import AnalysisWorkflow +from src.utils.databricks_rest_connector import DatabricksRestConnector +from src.utils.northwind_bootstrapper import bootstrap_northwind, check_northwind_exists +from src.utils.databricks_workspace_manager import create_databricks_workspace_if_needed + +def initialize_components(): + """ + Initialize application components. + + Returns: + Dictionary of initialized components + """ + # Get environment variables + aws_region = os.getenv('AWS_REGION', 'us-east-1') + s3_bucket = os.getenv('S3_BUCKET', 'your-bucket-name') + + # Initialize Bedrock client + bedrock = BedrockHelper(region_name=aws_region) + + # Initialize vector store + vector_store = FAISSManager( + bedrock_client=bedrock, + s3_bucket=s3_bucket + ) + + # No monitoring needed + monitor = None + + # Initialize workflow + workflow = AnalysisWorkflow( + bedrock_helper=bedrock, + vector_store=vector_store, + monitor=monitor + ) + + return { + 'bedrock': bedrock, + 'vector_store': vector_store, + 'monitor': monitor, + 'workflow': workflow + } + + +def load_all_metadata(vector_store, show_progress=False): + """ + Load metadata from Northwind tables. 
+ """ + catalog = os.getenv('DATABRICKS_CATALOG', 'workspace') + schema = os.getenv('DATABRICKS_SCHEMA', 'northwind') + + # Create simple schema context for Northwind + schema_text = f""" + Catalog: {catalog}, Schema: {schema} + + Table: customers - Customer information + Columns: customerid (string), companyname (string), contactname (string), country (string) + + Table: orders - Order information + Columns: orderid (int), customerid (string), orderdate (date), freight (decimal), shipcountry (string) + + Table: order_details - Order line items + Columns: orderid (int), productid (int), unitprice (decimal), quantity (int) + + Table: products - Product catalog + Columns: productid (int), productname (string), categoryid (int), unitprice (decimal) + + Table: categories - Product categories + Columns: categoryid (int), categoryname (string), description (string) + + Table: suppliers - Supplier information + Columns: supplierid (int), companyname (string), country (string) + + Table: employees - Employee data + Columns: employeeid (int), lastname (string), firstname (string), title (string) + + Table: shippers - Shipping companies + Columns: shipperid (int), companyname (string), phone (string) + """ + + # Add to vector store + texts = [schema_text] + metadatas = [{'catalog': catalog, 'schema': schema, 'type': 'schema'}] + + # Get embeddings + embeddings = [] + for text in texts: + embedding = vector_store.bedrock_client.get_embeddings(text) + embeddings.append(embedding) + + if embeddings: + embeddings_array = np.array(embeddings).astype('float32') + if embeddings_array.ndim == 1: + embeddings_array = embeddings_array.reshape(1, -1) + + vector_store.texts = texts + vector_store.metadata = metadatas + vector_store.index.add(embeddings_array) + + if show_progress: + st.sidebar.success(f"βœ… Loaded Northwind schema metadata") + + return pd.DataFrame({'schema': [schema], 'loaded': [True]}) + + return None + + +def main(): + """ + Main application function. + """ + # Set page config + st.set_page_config( + page_title="Sales Data Analyst (Databricks)", + page_icon="πŸ“Š", + layout="wide" + ) + + # Hide Streamlit branding + hide_streamlit_style = """ + + """ + st.markdown(hide_streamlit_style, unsafe_allow_html=True) + + # Custom CSS + st.markdown(""" + + """, unsafe_allow_html=True) + + # Header + st.markdown('
<h1>Sales Data Analyst</h1>', unsafe_allow_html=True) + st.markdown('<h4>(Powered by Amazon Bedrock and Databricks)</h4>', unsafe_allow_html=True) + st.markdown('<hr>
', unsafe_allow_html=True) + + # Check Databricks configuration + if not os.getenv('DATABRICKS_HOST') or not os.getenv('DATABRICKS_TOKEN'): + st.error("❌ Missing Databricks configuration. Please add DATABRICKS_HOST and DATABRICKS_TOKEN to your .env file.") + st.stop() + + # Initialize components + components = initialize_components() + + # Setup Databricks workspace + try: + with st.spinner("Setting up Databricks workspace..."): + # Initialize REST connector + db_connector = DatabricksRestConnector() + + # Test connection + test_result = db_connector.execute_query("SELECT 1 as test") + if test_result: + st.sidebar.success("βœ… Connected to Databricks") + else: + st.sidebar.error("❌ Failed to connect to Databricks") + return + + # Always do fresh setup - no checks + if 'database_setup_complete' not in st.session_state: + progress_bar = st.progress(0) + status_text = st.empty() + + def update_progress(progress, message): + progress_bar.progress(progress) + status_text.text(f"πŸš€ {message}") + + success = bootstrap_northwind( + show_progress=False, + fresh_start=True, + progress_callback=update_progress + ) + + if success: + time.sleep(1) + progress_bar.empty() + status_text.empty() + st.sidebar.success("βœ… Database ready") + st.session_state.database_setup_complete = True + st.session_state.metadata_loaded = False + else: + status_text.error("❌ Setup failed") + return + else: + st.sidebar.success("βœ… Database ready") + + + + except Exception as e: + st.sidebar.error(f"❌ Databricks connection failed: {str(e)}") + return + + # Load metadata + if 'metadata_loaded' not in st.session_state or not st.session_state.metadata_loaded: + try: + metadata_df = load_all_metadata(components['vector_store'], show_progress=True) + if metadata_df is not None and len(metadata_df) > 0: + st.session_state.metadata_df = metadata_df + st.session_state.metadata_loaded = True + st.session_state.metadata_count = len(metadata_df) + st.sidebar.success(f"βœ… Loaded metadata for {len(metadata_df)} schemas") + else: + st.sidebar.warning("⚠️ No metadata loaded - database may still be setting up") + st.session_state.metadata_loaded = False + except Exception as e: + st.sidebar.error(f"❌ Error loading metadata: {str(e)}") + st.session_state.metadata_loaded = False + else: + st.sidebar.success("βœ… Metadata ready") + + # Sidebar + with st.sidebar: + st.header("Settings") + + + + # Workflow status + if components['workflow']: + st.success("βœ… Analysis workflow enabled") + + # Reload metadata button + if st.button("πŸ”„ Reload Metadata", key="reload_metadata"): + with st.spinner("Reloading database metadata..."): + st.session_state.metadata_loaded = False + metadata_df = load_all_metadata(components['vector_store'], show_progress=True) + if metadata_df is not None and len(metadata_df) > 0: + st.session_state.metadata_df = metadata_df + st.session_state.metadata_loaded = True + st.session_state.metadata_count = len(metadata_df) + st.success(f"βœ… Reloaded metadata for {len(metadata_df)} schemas") + st.rerun() + else: + st.error("❌ Failed to reload metadata") + + # Available data section moved to sidebar + st.header("πŸ“‹ Available Data") + st.markdown(""" + **🏒 Business Data:** + - πŸ‘₯ **Customers** - Company details, contacts, locations + - πŸ“¦ **Orders** - Order dates, shipping info, freight costs + - πŸ›οΈ **Order Details** - Products, quantities, prices, discounts + + **🏭 Product Catalog:** + - 🎯 **Products** - Names, prices, stock levels + - πŸ“‚ **Categories** - Product groupings and descriptions + - 🚚 
**Suppliers** - Vendor information and contacts + + **πŸ‘¨πŸ’Ό Operations:** + - πŸ‘” **Employees** - Staff details and hierarchy + - πŸš› **Shippers** - Delivery companies and contacts + """) + + # Show available catalogs and schemas + with st.expander("Database Explorer", expanded=False): + if st.button("Show Catalogs"): + try: + catalogs = get_available_catalogs() + st.write("Available catalogs:") + st.write(", ".join(catalogs)) + except Exception as e: + st.error(f"Error listing catalogs: {str(e)}") + + # Main content area + col1 = st.container() + + with col1: + st.markdown('
<h3>Ask questions about your sales data</h3>', unsafe_allow_html=True) + st.markdown('<p>You can ask about customer orders, product sales, and more.</p>
', unsafe_allow_html=True) + + # Examples + with st.expander("πŸ’‘ Example questions", expanded=False): + st.markdown(""" + **βœ… Try these working questions:** + + 1. **What are the top 10 customers by total order value?** + 2. **Which products generate the most revenue?** + 3. **What's the average order value by country?** + 4. **Which product categories sell the most?** + 5. **What are the top 5 most expensive products?** + 6. **How many orders come from each country?** + 7. **Which countries have the highest average order values?** + 8. **Who are our most frequent customers?** + 9. **Which suppliers provide the most products?** + 10. **Which employees process the most orders?** + """) + + # Question input + question = st.text_input( + "Ask your question:", + placeholder="e.g., What are the top 5 customers by order value?" + ) + + # Process question + if question: + if 'metadata_df' not in st.session_state or not st.session_state.get('metadata_loaded', False): + st.error("Metadata not loaded. Please click 'Reload Metadata' button in the sidebar.") + return + + try: + # Execute workflow + with st.spinner("Processing your question..."): + result = components['workflow'].execute(question, db_connector.execute_query) + + # Display workflow steps + with st.expander("Workflow Steps", expanded=False): + steps = result.get("steps_completed", []) + for step in steps: + if "error" in step: + st.markdown(f'
<span style="color:red">{step}</span>', unsafe_allow_html=True) + else: + st.markdown(f'<span style="color:green">{step}</span>
', unsafe_allow_html=True) + + # Display error if any + if "error" in result: + st.error(result.get("friendly_error", result["error"])) + + # Display SQL if generated + if "generated_sql" in result: + with st.expander("Generated SQL", expanded=True): + st.code(result["generated_sql"], language="sql") + + # Display results if available + if "query_results" in result: + st.write(f"Query executed in {result.get('execution_time', 0):.2f} seconds, returned {len(result['query_results'])} rows") + with st.expander("Query Results", expanded=True): + st.dataframe(result["query_results"]) + + # Display analysis + if "analysis" in result: + st.subheader("Analysis") + st.write(result["analysis"]) + + # Save to history + if 'history' not in st.session_state: + st.session_state.history = [] + + st.session_state.history.append({ + 'question': question, + 'sql': result.get('generated_sql', ''), + 'results': result.get('query_results', [])[:10], + 'analysis': result.get('analysis', ''), + 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S') + }) + + except Exception as e: + st.error(f"Error: {str(e)}") + + # Show history + if 'history' in st.session_state and st.session_state.history: + with st.expander("Query History", expanded=False): + for i, item in enumerate(reversed(st.session_state.history[-5:])): + st.write(f"**{item['timestamp']}**: {item['question']}") + if st.button(f"Show details", key=f"history_{i}"): + st.code(item['sql'], language="sql") + st.dataframe(item['results']) + st.write(item['analysis']) + st.divider() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/cleanup.py b/Sales-Analyst-Bedrock-Databricks/cleanup.py new file mode 100755 index 00000000..fa8fc157 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/cleanup.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +""" +Cleanup script for Sales Analyst Databricks demo. +Removes all created resources to avoid ongoing costs. +""" +import os +from dotenv import load_dotenv +from src.utils.databricks_rest_connector import DatabricksRestConnector +import requests + +load_dotenv() + +def cleanup_databricks_resources(): + """Clean up all Databricks resources created by the demo.""" + try: + host = os.getenv('DATABRICKS_HOST', '').rstrip('/') + token = os.getenv('DATABRICKS_TOKEN', '') + catalog = os.getenv('DATABRICKS_CATALOG', 'workspace') + schema = os.getenv('DATABRICKS_SCHEMA', 'northwind') + + if not host or not token: + print("❌ Missing Databricks credentials in .env file") + return False + + headers = { + 'Authorization': f'Bearer {token}', + 'Content-Type': 'application/json' + } + + connector = DatabricksRestConnector() + + print("🧹 Starting cleanup of Databricks resources...") + + # 1. Drop schema and all tables + print(f"πŸ—‘οΈ Dropping schema {catalog}.{schema}...") + try: + connector.execute_query(f"DROP SCHEMA IF EXISTS {catalog}.{schema} CASCADE") + print(f"βœ… Dropped schema {catalog}.{schema}") + except Exception as e: + print(f"⚠️ Could not drop schema: {e}") + + # 2. 
Stop and delete custom warehouse (if created) + print("πŸ›‘ Checking for custom warehouses to clean up...") + try: + response = requests.get(f"{host}/api/2.0/sql/warehouses", headers=headers) + if response.status_code == 200: + warehouses = response.json().get('warehouses', []) + for warehouse in warehouses: + if warehouse.get('name') == 'sales-analyst': + warehouse_id = warehouse['id'] + print(f"πŸ›‘ Stopping warehouse: {warehouse_id}") + + # Stop warehouse + requests.post(f"{host}/api/2.0/sql/warehouses/{warehouse_id}/stop", headers=headers) + + # Delete warehouse + response = requests.delete(f"{host}/api/2.0/sql/warehouses/{warehouse_id}", headers=headers) + if response.status_code == 200: + print(f"βœ… Deleted custom warehouse: {warehouse_id}") + else: + print(f"⚠️ Could not delete warehouse: {response.text}") + break + else: + print("ℹ️ No custom warehouses found (using default Serverless Starter Warehouse)") + except Exception as e: + print(f"⚠️ Error managing warehouses: {e}") + + # 3. Clean up local cache files + print("🧹 Cleaning up local cache files...") + cache_files = [ + 'metadata_cache.pkl', + '__pycache__', + 'src/__pycache__', + 'src/bedrock/__pycache__', + 'src/graph/__pycache__', + 'src/utils/__pycache__', + 'src/vector_store/__pycache__', + 'src/monitoring/__pycache__' + ] + + for cache_file in cache_files: + try: + if os.path.exists(cache_file): + if os.path.isfile(cache_file): + os.remove(cache_file) + else: + import shutil + shutil.rmtree(cache_file) + print(f"βœ… Removed {cache_file}") + except Exception as e: + print(f"⚠️ Could not remove {cache_file}: {e}") + + print("\nβœ… Cleanup completed!") + print("\nResources cleaned up:") + print(f" β€’ Schema: {catalog}.{schema} (and all tables)") + print(" β€’ Custom warehouse: sales-analyst (if existed)") + print(" β€’ Local cache files") + print("\nNote: Default 'Serverless Starter Warehouse' was preserved") + print("πŸ’° This should eliminate ongoing Databricks costs from the demo") + + return True + + except Exception as e: + print(f"❌ Cleanup failed: {e}") + return False + +if __name__ == "__main__": + print("🧹 Sales Analyst Databricks Cleanup") + print("=" * 40) + + confirm = input("This will delete all demo data and resources. Continue? 
(y/N): ") + if confirm.lower() in ['y', 'yes']: + cleanup_databricks_resources() + else: + print("Cleanup cancelled.") \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/images/.gitkeep b/Sales-Analyst-Bedrock-Databricks/images/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/Sales-Analyst-Bedrock-Databricks/images/architecture.png b/Sales-Analyst-Bedrock-Databricks/images/architecture.png new file mode 100644 index 00000000..36af740f Binary files /dev/null and b/Sales-Analyst-Bedrock-Databricks/images/architecture.png differ diff --git a/Sales-Analyst-Bedrock-Databricks/images/demo.gif b/Sales-Analyst-Bedrock-Databricks/images/demo.gif new file mode 100644 index 00000000..94985c68 Binary files /dev/null and b/Sales-Analyst-Bedrock-Databricks/images/demo.gif differ diff --git a/Sales-Analyst-Bedrock-Databricks/requirements.txt b/Sales-Analyst-Bedrock-Databricks/requirements.txt new file mode 100644 index 00000000..8a633f2b --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/requirements.txt @@ -0,0 +1,10 @@ +streamlit>=1.24.0 +pandas>=1.5.3 +numpy>=1.24.3 +boto3>=1.28.0 +databricks-sql-connector>=2.9.0 +databricks-cli>=0.18.0 +faiss-cpu>=1.7.4 +python-dotenv>=1.0.0 +langfuse>=3.0.1 +requests>=2.31.0 \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/setup.py b/Sales-Analyst-Bedrock-Databricks/setup.py new file mode 100644 index 00000000..21e730c6 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/setup.py @@ -0,0 +1,7 @@ +""" +Setup script for the GenAI Sales Analyst application. +""" +from src.utils.setup_utils import run_setup + +if __name__ == "__main__": + run_setup() \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/__init__.py b/Sales-Analyst-Bedrock-Databricks/src/__init__.py new file mode 100644 index 00000000..06d9edf8 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/__init__.py @@ -0,0 +1 @@ +# Sales Analyst Databricks Package \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/bedrock/__init__.py b/Sales-Analyst-Bedrock-Databricks/src/bedrock/__init__.py new file mode 100644 index 00000000..5989b81e --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/bedrock/__init__.py @@ -0,0 +1 @@ +# Bedrock package \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/bedrock/bedrock_helper.py b/Sales-Analyst-Bedrock-Databricks/src/bedrock/bedrock_helper.py new file mode 100644 index 00000000..8f9c2475 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/bedrock/bedrock_helper.py @@ -0,0 +1,85 @@ +""" +Amazon Bedrock helper for the GenAI Sales Analyst application. +""" +import boto3 +import json +from typing import List, Dict, Any, Optional + + +class BedrockHelper: + """ + Helper class for Amazon Bedrock operations. + """ + + def __init__(self, region_name: str = 'us-east-1'): + """ + Initialize the Bedrock helper. + + Args: + region_name: AWS region name + """ + self.bedrock_runtime = boto3.client( + service_name='bedrock-runtime', + region_name=region_name + ) + + def invoke_model(self, + prompt: str, + model_id: str = "anthropic.claude-3-sonnet-20240229-v1:0", + max_tokens: int = 4096, + temperature: float = 0.7) -> str: + """ + Invoke a Bedrock model with a prompt. 
+ + Args: + prompt: Input prompt text + model_id: Bedrock model ID + max_tokens: Maximum tokens to generate + temperature: Temperature for generation + + Returns: + Model response text + """ + body = json.dumps({ + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": max_tokens, + "temperature": temperature, + "messages": [ + { + "role": "user", + "content": prompt + } + ] + }) + + try: + response = self.bedrock_runtime.invoke_model( + modelId=model_id, + body=body + ) + response_body = json.loads(response['body'].read()) + return response_body['content'][0]['text'] + except Exception as e: + print(f"Error invoking Bedrock: {str(e)}") + raise + + def get_embeddings(self, text: str) -> List[float]: + """ + Get embeddings for a text using Bedrock. + + Args: + text: Input text + + Returns: + List of embedding values + """ + try: + response = self.bedrock_runtime.invoke_model( + modelId="amazon.titan-embed-text-v1", + body=json.dumps({"inputText": text}) + ) + response_body = json.loads(response['body'].read()) + return response_body['embedding'] + except Exception as e: + print(f"Error getting embeddings: {str(e)}") + raise \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/graph/__init__.py b/Sales-Analyst-Bedrock-Databricks/src/graph/__init__.py new file mode 100644 index 00000000..93b88780 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/graph/__init__.py @@ -0,0 +1 @@ +# Graph package \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/graph/workflow.py b/Sales-Analyst-Bedrock-Databricks/src/graph/workflow.py new file mode 100644 index 00000000..777aa946 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/graph/workflow.py @@ -0,0 +1,427 @@ +""" +LangGraph workflow for the GenAI Sales Analyst application (Databricks version). +""" +from typing import Dict, Any, List, Tuple +import json +from datetime import datetime + + +class AnalysisWorkflow: + """ + LangGraph workflow for sales data analysis with Databricks. + """ + + def __init__(self, bedrock_helper, vector_store, monitor=None): + """ + Initialize the analysis workflow. + + Args: + bedrock_helper: Client for Amazon Bedrock API + vector_store: Vector store for similarity search + monitor: Optional monitoring client + """ + self.bedrock = bedrock_helper + self.vector_store = vector_store + self.monitor = monitor + + def understand_query(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Understand and classify the user query. + """ + query = state['query'] + + prompt = f"""Analyze this query and classify it: + +Query: {query} + +Determine: +1. Query type (analysis/sql/metadata/comparison) +2. Required data sources or tables +3. Time frame mentioned (if any) +4. Specific metrics requested (if any) + +Return as JSON with these fields. 
+""" + + try: + response = self.bedrock.invoke_model(prompt) + + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_interaction( + prompt=prompt, + response=response, + metadata={ + "step_name": "understand_query", + "query": query + }, + trace_id=state.get('trace_id') + ) + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + + try: + analysis = json.loads(response) + except json.JSONDecodeError: + analysis = { + "type": "analysis", + "data_sources": [], + "time_frame": "not specified", + "metrics": [] + } + + return { + **state, + "query_analysis": analysis, + "steps_completed": state.get("steps_completed", []) + ["understand_query"] + } + except Exception as e: + return { + **state, + "error": f"Error in understand_query: {str(e)}", + "steps_completed": state.get("steps_completed", []) + ["understand_query_error"] + } + + def retrieve_context(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Retrieve relevant context from vector store. + """ + if "error" in state: + return state + + query = state['query'] + + try: + similar_docs = self.vector_store.similarity_search(query, k=5) + + if not similar_docs: + if "schema" in query.lower() and "customer" in query.lower(): + return { + **state, + "generated_sql": "DESCRIBE workspace.northwind.customers;", + "skip_context": True, + "steps_completed": state.get("steps_completed", []) + ["retrieve_context", "direct_sql"] + } + else: + return { + **state, + "relevant_context": [{ + "text": "Use workspace.northwind schema with Delta tables: customers, orders, order_details, products, categories, suppliers, employees, shippers" + }], + "steps_completed": state.get("steps_completed", []) + ["retrieve_context", "fallback_context"] + } + + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_interaction( + prompt=query, + response=str(similar_docs), + metadata={ + "step_name": "retrieve_context", + "num_results": len(similar_docs) + }, + trace_id=state.get('trace_id') + ) + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + + return { + **state, + "relevant_context": similar_docs, + "steps_completed": state.get("steps_completed", []) + ["retrieve_context"] + } + except Exception as e: + return { + **state, + "error": f"Error in retrieve_context: {str(e)}", + "steps_completed": state.get("steps_completed", []) + ["retrieve_context_error"] + } + + def generate_sql(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Generate SQL based on the query and context for Databricks. 
+ """ + if "error" in state: + return state + + if "skip_context" in state and state["skip_context"] and "generated_sql" in state: + return state + + query = state['query'] + context = state.get('relevant_context', []) + + context_str = "\n".join([f"- {doc['text']}" for doc in context]) + + if not context_str: + prompt = f"""Generate a SQL query to answer this question for Databricks: + +Question: {query} + +Use the workspace.northwind catalog with these Delta tables: +- workspace.northwind.categories: categoryid (BIGINT), categoryname (STRING), description (STRING) +- workspace.northwind.customers: customerid (STRING), companyname (STRING), contactname (STRING), country (STRING) +- workspace.northwind.employees: employeeid (BIGINT), lastname (STRING), firstname (STRING), title (STRING) +- workspace.northwind.products: productid (BIGINT), productname (STRING), supplierid (BIGINT), categoryid (BIGINT), unitprice (DOUBLE) +- workspace.northwind.suppliers: supplierid (BIGINT), companyname (STRING), country (STRING) +- workspace.northwind.shippers: shipperid (BIGINT), companyname (STRING), phone (STRING) +- workspace.northwind.orders: orderid (BIGINT), customerid (STRING), employeeid (BIGINT), orderdate (STRING), freight (DOUBLE), shipcountry (STRING) +- workspace.northwind.order_details: orderid (BIGINT), productid (BIGINT), unitprice (DOUBLE), quantity (BIGINT), discount (DOUBLE) + +NO CAST operations needed - all numeric columns have proper types + +IMPORTANT SQL RULES for Databricks: +1. Always use catalog.schema.table format (e.g., workspace.northwind.customers) +2. Use lowercase table and column names +3. Do NOT nest aggregate functions (AVG, SUM, COUNT, etc.) +4. Use subqueries or CTEs for complex calculations +5. ALL COLUMNS ARE STRING TYPE - use CAST for calculations: CAST(unitprice AS DECIMAL(10,2)) * CAST(quantity AS INT) +6. Generate valid Databricks SQL syntax +7. Use LIMIT instead of TOP for row limiting + +For "average order value by customer" type queries, use this pattern: +SELECT customerid, companyname, AVG(order_total) as avg_order_value +FROM ( + SELECT c.customerid, c.companyname, o.orderid, SUM(CAST(od.unitprice AS DECIMAL(10,2)) * CAST(od.quantity AS INT)) as order_total + FROM workspace.northwind.customers c + JOIN workspace.northwind.orders o ON c.customerid = o.customerid + JOIN workspace.northwind.order_details od ON o.orderid = od.orderid + GROUP BY c.customerid, c.companyname, o.orderid +) subquery +GROUP BY customerid, companyname +ORDER BY avg_order_value DESC; + +Generate ONLY the SQL query without any explanation. +""" + else: + prompt = f"""Generate a SQL query to answer this question for Databricks: + +Question: {query} + +Relevant context: +{context_str} + +IMPORTANT SQL RULES for Databricks: +1. Always use catalog.schema.table format (e.g., workspace.northwind.customers) +2. Use lowercase table and column names +3. Do NOT nest aggregate functions (AVG, SUM, COUNT, etc.) +4. Use subqueries or CTEs for complex calculations +5. ALL COLUMNS ARE STRING TYPE - use CAST for calculations: CAST(unitprice AS DECIMAL(10,2)) * CAST(quantity AS INT) +6. Generate valid Databricks SQL syntax +7. Use LIMIT instead of TOP for row limiting + +Generate ONLY the SQL query without any explanation. 
+""" + + try: + sql = self.bedrock.invoke_model(prompt) + + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_interaction( + prompt=prompt, + response=sql, + metadata={ + "step_name": "generate_sql", + "query": query + }, + trace_id=state.get('trace_id') + ) + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + + return { + **state, + "generated_sql": sql.strip(), + "steps_completed": state.get("steps_completed", []) + ["generate_sql"] + } + except Exception as e: + return { + **state, + "error": f"Error in generate_sql: {str(e)}", + "steps_completed": state.get("steps_completed", []) + ["generate_sql_error"] + } + + def analyze_results(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Analyze query results and provide an answer. + """ + if "error" in state: + return state + + query = state['query'] + sql = state.get('generated_sql', '') + results = state.get('query_results', []) + + if not results: + analysis = "No results found for this query." + + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_interaction( + prompt=query, + response=analysis, + metadata={ + "step_name": "analyze_results", + "results_count": 0 + }, + trace_id=state.get('trace_id') + ) + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + + return { + **state, + "analysis": analysis, + "steps_completed": state.get("steps_completed", []) + ["analyze_results"] + } + + results_str = "\n".join([str(row) for row in results[:10]]) + if len(results) > 10: + results_str += f"\n... and {len(results) - 10} more rows" + + prompt = f"""Analyze these query results to answer the user's question: + +Question: {query} + +SQL Query: +{sql} + +Query Results (first 10 rows): +{results_str} + +Provide a clear, concise analysis that directly answers the question. Include key insights from the data. +""" + + try: + analysis = self.bedrock.invoke_model(prompt) + + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_interaction( + prompt=prompt, + response=analysis, + metadata={ + "step_name": "analyze_results", + "results_count": len(results) + }, + trace_id=state.get('trace_id') + ) + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + + return { + **state, + "analysis": analysis.strip(), + "steps_completed": state.get("steps_completed", []) + ["analyze_results"] + } + except Exception as e: + return { + **state, + "error": f"Error in analyze_results: {str(e)}", + "steps_completed": state.get("steps_completed", []) + ["analyze_results_error"] + } + + def handle_error(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Handle errors in the workflow. + """ + error = state.get('error', 'Unknown error') + + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_error( + error_message=error, + metadata={ + "query": state.get('query', ''), + "steps_completed": state.get('steps_completed', []) + } + ) + except Exception as e: + print(f"Error logging error to LangFuse: {str(e)}") + + prompt = f"""An error occurred while processing this query: + +Query: {state.get('query', '')} + +Error: {error} + +Generate a user-friendly error message explaining what went wrong and suggesting how to fix it. +""" + + try: + friendly_message = self.bedrock.invoke_model(prompt) + except Exception: + friendly_message = f"Sorry, an error occurred: {error}. Please try rephrasing your question." 
+ + return { + **state, + "error_handled": True, + "friendly_error": friendly_message.strip(), + "steps_completed": state.get("steps_completed", []) + ["handle_error"] + } + + def execute(self, query: str, execute_query_func=None) -> Dict[str, Any]: + """ + Execute the analysis workflow. + """ + trace_id = None + if self.monitor and self.monitor.enabled: + try: + from uuid import uuid4 + trace_id = f"workflow-{uuid4()}" + except Exception: + pass + + state = { + "query": query, + "timestamp": datetime.now().isoformat(), + "trace_id": trace_id, + "steps_completed": [] + } + + state = self.understand_query(state) + + if "error" not in state: + state = self.retrieve_context(state) + + if "error" not in state: + state = self.generate_sql(state) + + if "generated_sql" in state and "error" not in state and execute_query_func: + try: + start_time = datetime.now() + results = execute_query_func(state["generated_sql"]) + end_time = datetime.now() + execution_time = (end_time - start_time).total_seconds() + + state["query_results"] = results + state["execution_time"] = execution_time + + state = self.analyze_results(state) + + except Exception as e: + state["error"] = f"Error executing SQL: {str(e)}" + state = self.handle_error(state) + elif "error" in state: + state = self.handle_error(state) + + if self.monitor and self.monitor.enabled: + try: + steps = [] + for step in state.get("steps_completed", []): + step_data = {"name": step} + steps.append(step_data) + + self.monitor.log_workflow( + workflow_name="analysis_workflow", + steps=steps, + metadata={ + "query": query, + "execution_time": state.get("execution_time"), + "error": state.get("error") + } + ) + except Exception as e: + print(f"Error logging workflow to LangFuse: {str(e)}") + + return state \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/monitoring/__init__.py b/Sales-Analyst-Bedrock-Databricks/src/monitoring/__init__.py new file mode 100644 index 00000000..85af06d3 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/monitoring/__init__.py @@ -0,0 +1 @@ +# Monitoring package \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/monitoring/langfuse_monitor.py b/Sales-Analyst-Bedrock-Databricks/src/monitoring/langfuse_monitor.py new file mode 100644 index 00000000..9031c7d0 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/monitoring/langfuse_monitor.py @@ -0,0 +1,133 @@ +""" +LangFuse monitoring integration for the GenAI Sales Analyst application. +""" +import uuid +from typing import Dict, Any, Optional, List +import os + + +class LangfuseMonitor: + """ + Manages LangFuse monitoring integration. + """ + + def __init__(self, public_key: str, secret_key: str, host: str = None): + """ + Initialize the LangFuse monitor. + + Args: + public_key: LangFuse public API key + secret_key: LangFuse secret API key + host: Optional LangFuse host URL + """ + try: + from langfuse import Langfuse + + langfuse_args = { + "public_key": public_key, + "secret_key": secret_key + } + + if host: + langfuse_args["host"] = host + + self.client = Langfuse(**langfuse_args) + self.enabled = True + print("LangFuse monitoring enabled") + except ImportError: + print("Langfuse package not installed. 
Monitoring will be disabled.") + self.enabled = False + except Exception as e: + print(f"Error initializing LangFuse: {str(e)}") + self.enabled = False + + def log_interaction(self, + prompt: str, + response: str, + metadata: Dict[str, Any], + trace_id: Optional[str] = None) -> Optional[str]: + """ + Log an interaction to LangFuse. + """ + if not self.enabled: + return None + + try: + with self.client.start_as_current_generation( + name=metadata.get("step_name", "model_call"), + model=metadata.get("model_id", "anthropic.claude-3-sonnet"), + prompt=prompt, + completion=response, + metadata=metadata + ): + pass + + self.client.flush() + current_trace_id = self.client.get_current_trace_id() + return current_trace_id + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + return None + + def log_workflow(self, + workflow_name: str, + steps: List[Dict[str, Any]], + metadata: Dict[str, Any]) -> Optional[str]: + """ + Log a complete workflow execution to LangFuse. + """ + if not self.enabled: + return None + + try: + with self.client.start_as_current_span( + name=workflow_name, + metadata=metadata + ): + for step in steps: + with self.client.start_as_current_span( + name=step.get("name", "unknown_step"), + metadata=step.get("metadata", {}) + ): + pass + + self.client.flush() + current_trace_id = self.client.get_current_trace_id() + return current_trace_id + except Exception as e: + print(f"Error logging workflow to LangFuse: {str(e)}") + return None + + def log_error(self, error_message: str, metadata: Dict[str, Any] = None) -> None: + """ + Log an error to LangFuse. + """ + if not self.enabled: + return + + try: + if metadata is None: + metadata = {} + + with self.client.start_as_current_span( + name="error", + metadata={ + **metadata, + "error_message": error_message, + "level": "error" + } + ): + pass + + trace_id = self.client.get_current_trace_id() + if trace_id: + self.client.create_score( + name="error", + value=0, + trace_id=trace_id + ) + + self.client.flush() + except Exception as e: + print(f"Error logging error to LangFuse: {str(e)}") + return \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/utils/__init__.py b/Sales-Analyst-Bedrock-Databricks/src/utils/__init__.py new file mode 100644 index 00000000..67b9db69 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/utils/__init__.py @@ -0,0 +1 @@ +# Utils package \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/utils/databricks_rest_connector.py b/Sales-Analyst-Bedrock-Databricks/src/utils/databricks_rest_connector.py new file mode 100644 index 00000000..9e53316f --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/utils/databricks_rest_connector.py @@ -0,0 +1,168 @@ +""" +Databricks REST API connector (alternative to SQL connector) +""" +import os +import requests +import time +from dotenv import load_dotenv + +load_dotenv() + +class DatabricksRestConnector: + def __init__(self): + self.host = os.getenv('DATABRICKS_HOST', '').rstrip('/') + self.token = os.getenv('DATABRICKS_TOKEN', '') + self.warehouse_id = os.getenv('DATABRICKS_CLUSTER_ID', '') + + self.headers = { + 'Authorization': f'Bearer {self.token}', + 'Content-Type': 'application/json' + } + + # Auto-create warehouse if needed + if not self.warehouse_id or self.warehouse_id == 'auto_created': + self.warehouse_id = self.get_or_create_warehouse() + + def get_or_create_warehouse(self): + """Get existing warehouse or create new one""" + try: + # Check for existing warehouses + response = 
requests.get( + f"{self.host}/api/2.0/sql/warehouses", + headers=self.headers + ) + + if response.status_code == 200: + warehouses = response.json().get('warehouses', []) + # Look for default Serverless Starter Warehouse first + for warehouse in warehouses: + if warehouse.get('name') == 'Serverless Starter Warehouse': + print(f"Using default Serverless Starter Warehouse: {warehouse['id']}") + return warehouse['id'] + # Look for existing sales-analyst warehouse + for warehouse in warehouses: + if warehouse.get('name') == 'sales-analyst': + print(f"Using existing warehouse: {warehouse['id']}") + return warehouse['id'] + + # Create new warehouse + print("Creating new SQL warehouse...") + response = requests.post( + f"{self.host}/api/2.0/sql/warehouses", + headers=self.headers, + json={ + "name": "sales-analyst", + "cluster_size": "2X-Small", + "min_num_clusters": 1, + "max_num_clusters": 1, + "auto_stop_mins": 10, + "enable_photon": True, + "warehouse_type": "PRO", + "enable_serverless_compute": False + } + ) + + if response.status_code == 200: + warehouse_id = response.json()['id'] + print(f"Created warehouse: {warehouse_id}") + + # Wait for warehouse to be ready + self.wait_for_warehouse_ready(warehouse_id) + return warehouse_id + else: + print(f"Failed to create warehouse: {response.text}") + return None + + except Exception as e: + print(f"Error managing warehouse: {e}") + return None + + def wait_for_warehouse_ready(self, warehouse_id, max_wait=300): + """Wait for warehouse to be ready""" + print("Waiting for warehouse to be ready...") + start_time = time.time() + + while time.time() - start_time < max_wait: + try: + response = requests.get( + f"{self.host}/api/2.0/sql/warehouses/{warehouse_id}", + headers=self.headers + ) + + if response.status_code == 200: + state = response.json().get('state') + if state == 'RUNNING': + print("Warehouse is ready!") + return True + elif state in ['STARTING', 'STOPPED']: + # Start the warehouse if stopped + requests.post( + f"{self.host}/api/2.0/sql/warehouses/{warehouse_id}/start", + headers=self.headers + ) + time.sleep(10) + else: + time.sleep(5) + else: + time.sleep(5) + + except Exception as e: + print(f"Error checking warehouse status: {e}") + time.sleep(5) + + print("Warehouse setup timed out") + return False + + def execute_query(self, query): + """Execute SQL query using REST API""" + try: + # Start query execution + response = requests.post( + f"{self.host}/api/2.0/sql/statements/", + headers=self.headers, + json={ + "warehouse_id": self.warehouse_id, + "statement": query, + "wait_timeout": "30s" + } + ) + + if response.status_code != 200: + raise Exception(f"Query failed: {response.text}") + + result = response.json() + statement_id = result.get('statement_id') + + # Poll for completion if still running + while result.get('status', {}).get('state') in ['PENDING', 'RUNNING']: + time.sleep(2) + response = requests.get( + f"{self.host}/api/2.0/sql/statements/{statement_id}", + headers=self.headers + ) + result = response.json() + + # Extract results + if result.get('status', {}).get('state') == 'SUCCEEDED': + data = result.get('result', {}).get('data_array', []) + columns = [col['name'] for col in result.get('manifest', {}).get('schema', {}).get('columns', [])] + + # Convert to list of dictionaries + return [dict(zip(columns, row)) for row in data] + else: + error_msg = result.get('status', {}).get('error', {}).get('message', 'Unknown error') + raise Exception(f"Query failed: {error_msg}") + + except Exception as e: + print(f"Error executing 
query: {e}") + return [] + +# Test function +def test_rest_connection(): + connector = DatabricksRestConnector() + result = connector.execute_query("SELECT 1 as test") + print(f"Test result: {result}") + return len(result) > 0 + +if __name__ == "__main__": + test_rest_connection() \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/utils/databricks_workspace_manager.py b/Sales-Analyst-Bedrock-Databricks/src/utils/databricks_workspace_manager.py new file mode 100644 index 00000000..aaa748fe --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/utils/databricks_workspace_manager.py @@ -0,0 +1,143 @@ +""" +Databricks workspace management utilities +""" +import boto3 +import time +import os +import requests +import streamlit as st + +def create_databricks_workspace_if_needed(): + """Create Databricks workspace and token automatically""" + try: + # Check if workspace already configured + if os.getenv('DATABRICKS_HOST') and os.getenv('DATABRICKS_TOKEN'): + return True + + st.info("πŸ”„ Creating new Databricks workspace in us-east-1...") + + session = boto3.Session( + aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), + aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), + region_name='us-east-1' + ) + + databricks_client = session.client('databricks') + workspace_name = "sales-analyst-workspace" + + # Create workspace + response = databricks_client.create_workspace( + WorkspaceName=workspace_name, + AwsRegion='us-east-1', + PricingTier='STANDARD', + DeploymentName=workspace_name.lower().replace('-', '') + ) + + workspace_id = response['WorkspaceId'] + + # Wait for ready status + progress_bar = st.progress(0) + status_text = st.empty() + + for i in range(20): # Max 10 minutes + status_response = databricks_client.describe_workspace(WorkspaceId=workspace_id) + status = status_response['WorkspaceStatus'] + + status_text.text(f"Workspace status: {status}") + progress_bar.progress((i + 1) / 20) + + if status == 'RUNNING': + workspace_url = status_response['WorkspaceUrl'] + st.success(f"βœ… Workspace ready: https://{workspace_url}") + + # Create token automatically + token = create_databricks_token(workspace_url) + if token: + update_env_file(workspace_url, token) + st.success("βœ… Token created and saved automatically") + st.info("πŸ”„ Please restart the app to continue") + return True + else: + st.error("❌ Failed to create token automatically") + return False + + elif status in ['FAILED', 'CANCELLED']: + st.error(f"❌ Workspace creation failed: {status}") + return False + + time.sleep(30) + + st.error("❌ Workspace creation timed out") + return False + + except Exception as e: + st.error(f"❌ Error creating workspace: {e}") + return False + +def create_databricks_token(workspace_url): + """Create a personal access token using Databricks API""" + try: + import requests + + # Use AWS credentials to authenticate with Databricks + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {os.getenv("AWS_ACCESS_KEY_ID")}:{os.getenv("AWS_SECRET_ACCESS_KEY")}' + } + + # Create token via API + token_data = { + 'comment': 'Auto-generated token for Sales Analyst app', + 'lifetime_seconds': 7776000 # 90 days + } + + response = requests.post( + f'https://{workspace_url}/api/2.0/token/create', + headers=headers, + json=token_data + ) + + if response.status_code == 200: + return response.json()['token_value'] + else: + st.error(f"Token creation failed: {response.text}") + return None + + except Exception as e: + st.error(f"Error creating token: {e}") + return 
None + +def update_env_file(workspace_url, token): + """Update .env file with workspace URL and token""" + try: + env_path = '.env' + env_lines = [] + + if os.path.exists(env_path): + with open(env_path, 'r') as f: + env_lines = f.readlines() + + # Update or add settings + updated_lines = [] + host_found = token_found = False + + for line in env_lines: + if line.startswith('DATABRICKS_HOST='): + updated_lines.append(f'DATABRICKS_HOST=https://{workspace_url}\n') + host_found = True + elif line.startswith('DATABRICKS_TOKEN='): + updated_lines.append(f'DATABRICKS_TOKEN={token}\n') + token_found = True + else: + updated_lines.append(line) + + if not host_found: + updated_lines.append(f'DATABRICKS_HOST=https://{workspace_url}\n') + if not token_found: + updated_lines.append(f'DATABRICKS_TOKEN={token}\n') + + with open(env_path, 'w') as f: + f.writelines(updated_lines) + + except Exception as e: + st.error(f"Could not update .env file: {e}") \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/utils/github_data_loader.py b/Sales-Analyst-Bedrock-Databricks/src/utils/github_data_loader.py new file mode 100644 index 00000000..eda49e21 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/utils/github_data_loader.py @@ -0,0 +1,206 @@ +""" +GitHub data loader for complete Northwind dataset. +""" +import requests +import pandas as pd +import io +import tempfile +import os + +def download_northwind_from_github(): + """Download complete Northwind dataset from GitHub and save to temp directory.""" + + # GitHub raw URLs for Northwind CSV files + base_url = "https://raw.githubusercontent.com/graphql-compose/graphql-compose-examples/master/examples/northwind/data/csv/" + + tables = { + 'categories': 'categories.csv', + 'customers': 'customers.csv', + 'employees': 'employees.csv', + 'order_details': 'order_details.csv', + 'orders': 'orders.csv', + 'products': 'products.csv', + 'shippers': 'shippers.csv', + 'suppliers': 'suppliers.csv' + } + + # Create temp directory + temp_dir = tempfile.mkdtemp(prefix='northwind_') + print(f"Downloading to: {temp_dir}") + + for table_name, filename in tables.items(): + print(f"Downloading {table_name}...") + + # Try primary source + try: + url = base_url + filename + response = requests.get(url, timeout=30) + if response.status_code == 200: + df = pd.read_csv(io.StringIO(response.text)) + df = normalize_column_names(df, table_name) + + # Save to CSV + csv_path = os.path.join(temp_dir, f"{table_name}.csv") + df.to_csv(csv_path, index=False) + print(f"βœ… Downloaded {table_name}: {len(df)} rows") + continue + except Exception as e: + print(f"Primary source failed for {table_name}: {e}") + + # Try alternative sources + alt_names = [ + filename, + filename.replace('_', ''), + filename.replace('_', '-'), + table_name + '.csv' + ] + + success = False + for alt_name in alt_names: + try: + alt_url = f"https://raw.githubusercontent.com/jpwhite3/northwind-SQLite3/master/csv/{alt_name}" + response = requests.get(alt_url, timeout=30) + if response.status_code == 200: + df = pd.read_csv(io.StringIO(response.text)) + df = normalize_column_names(df, table_name) + + # Save to CSV + csv_path = os.path.join(temp_dir, f"{table_name}.csv") + df.to_csv(csv_path, index=False) + print(f"βœ… Downloaded {table_name}: {len(df)} rows") + success = True + break + except: + continue + + if not success: + print(f"❌ Failed to download {table_name}, creating sample data") + df = create_sample_table_data(table_name) + csv_path = os.path.join(temp_dir, f"{table_name}.csv") + 
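+            # Save the generated sample rows so downstream loading treats them the same as downloaded data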
df.to_csv(csv_path, index=False) + + return temp_dir + +def create_sample_table_data(table_name): + """Create sample data for tables that couldn't be downloaded.""" + + if table_name == 'customers': + return pd.DataFrame([ + {'customerid': 'ALFKI', 'companyname': 'Alfreds Futterkiste', 'contactname': 'Maria Anders', 'country': 'Germany', 'city': 'Berlin'}, + {'customerid': 'ANATR', 'companyname': 'Ana Trujillo Emparedados', 'contactname': 'Ana Trujillo', 'country': 'Mexico', 'city': 'MΓ©xico D.F.'}, + {'customerid': 'ANTON', 'companyname': 'Antonio Moreno TaquerΓ­a', 'contactname': 'Antonio Moreno', 'country': 'Mexico', 'city': 'MΓ©xico D.F.'}, + {'customerid': 'BERGS', 'companyname': 'Berglunds snabbkΓΆp', 'contactname': 'Christina Berglund', 'country': 'Sweden', 'city': 'LuleΓ₯'}, + {'customerid': 'BLAUS', 'companyname': 'Blauer See Delikatessen', 'contactname': 'Hanna Moos', 'country': 'Germany', 'city': 'Mannheim'} + ]) + + elif table_name == 'orders': + return pd.DataFrame([ + {'orderid': 10248, 'customerid': 'ALFKI', 'orderdate': '1996-07-04', 'shipcountry': 'Germany', 'freight': 32.38}, + {'orderid': 10249, 'customerid': 'ANATR', 'orderdate': '1996-07-05', 'shipcountry': 'Mexico', 'freight': 11.61}, + {'orderid': 10250, 'customerid': 'ANTON', 'orderdate': '1996-07-08', 'shipcountry': 'Mexico', 'freight': 65.83}, + {'orderid': 10251, 'customerid': 'BERGS', 'orderdate': '1996-07-09', 'shipcountry': 'Sweden', 'freight': 41.34}, + {'orderid': 10252, 'customerid': 'BLAUS', 'orderdate': '1996-07-10', 'shipcountry': 'Germany', 'freight': 51.30} + ]) + + elif table_name == 'order_details': + return pd.DataFrame([ + {'orderid': 10248, 'productid': 11, 'unitprice': 14.0, 'quantity': 12, 'discount': 0.0}, + {'orderid': 10248, 'productid': 42, 'unitprice': 9.8, 'quantity': 10, 'discount': 0.0}, + {'orderid': 10249, 'productid': 14, 'unitprice': 18.6, 'quantity': 9, 'discount': 0.0}, + {'orderid': 10250, 'productid': 41, 'unitprice': 7.7, 'quantity': 10, 'discount': 0.0} + ]) + + elif table_name == 'products': + return pd.DataFrame([ + {'productid': 11, 'productname': 'Queso Cabrales', 'categoryid': 4, 'unitprice': 21.0}, + {'productid': 14, 'productname': 'Tofu', 'categoryid': 7, 'unitprice': 23.25}, + {'productid': 20, 'productname': 'Sir Rodneys Marmalade', 'categoryid': 3, 'unitprice': 81.0} + ]) + + elif table_name == 'categories': + return pd.DataFrame([ + {'categoryid': 1, 'categoryname': 'Beverages', 'description': 'Soft drinks, coffees, teas, beers, and ales'}, + {'categoryid': 2, 'categoryname': 'Condiments', 'description': 'Sweet and savory sauces, relishes, spreads, and seasonings'} + ]) + + elif table_name == 'suppliers': + return pd.DataFrame([ + {'supplierid': 1, 'companyname': 'Exotic Liquids', 'country': 'UK'}, + {'supplierid': 2, 'companyname': 'New Orleans Cajun Delights', 'country': 'USA'} + ]) + + elif table_name == 'employees': + return pd.DataFrame([ + {'employeeid': 1, 'lastname': 'Davolio', 'firstname': 'Nancy', 'title': 'Sales Representative'}, + {'employeeid': 2, 'lastname': 'Fuller', 'firstname': 'Andrew', 'title': 'Vice President, Sales'} + ]) + + elif table_name == 'shippers': + return pd.DataFrame([ + {'shipperid': 1, 'companyname': 'Speedy Express'}, + {'shipperid': 2, 'companyname': 'United Package'} + ]) + + return pd.DataFrame() + +def normalize_column_names(df, table_name): + """Normalize column names to match expected schema.""" + + # Common column name mappings + column_mappings = { + 'customers': { + 'CustomerID': 'customerid', + 'CompanyName': 
'companyname', + 'ContactName': 'contactname', + 'Country': 'country', + 'City': 'city' + }, + 'orders': { + 'OrderID': 'orderid', + 'CustomerID': 'customerid', + 'OrderDate': 'orderdate', + 'ShipCountry': 'shipcountry', + 'Freight': 'freight' + }, + 'order_details': { + 'OrderID': 'orderid', + 'ProductID': 'productid', + 'UnitPrice': 'unitprice', + 'Quantity': 'quantity', + 'Discount': 'discount' + }, + 'products': { + 'ProductID': 'productid', + 'ProductName': 'productname', + 'CategoryID': 'categoryid', + 'UnitPrice': 'unitprice' + }, + 'categories': { + 'CategoryID': 'categoryid', + 'CategoryName': 'categoryname', + 'Description': 'description' + }, + 'suppliers': { + 'SupplierID': 'supplierid', + 'CompanyName': 'companyname', + 'Country': 'country' + }, + 'employees': { + 'EmployeeID': 'employeeid', + 'LastName': 'lastname', + 'FirstName': 'firstname', + 'Title': 'title' + }, + 'shippers': { + 'ShipperID': 'shipperid', + 'CompanyName': 'companyname' + } + } + + if table_name in column_mappings: + df = df.rename(columns=column_mappings[table_name]) + + # Convert all column names to lowercase + df.columns = df.columns.str.lower() + + return df \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/utils/northwind_bootstrapper.py b/Sales-Analyst-Bedrock-Databricks/src/utils/northwind_bootstrapper.py new file mode 100644 index 00000000..8b71fadb --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/utils/northwind_bootstrapper.py @@ -0,0 +1,249 @@ +""" +Northwind database bootstrapper for Databricks. +""" +import os +import pandas as pd +import requests +import json +from .databricks_rest_connector import DatabricksRestConnector +from .github_data_loader import download_northwind_from_github + +def disable_table_acls(): + """Disable Table ACLs on the current cluster.""" + try: + host = os.getenv('DATABRICKS_HOST') + token = os.getenv('DATABRICKS_TOKEN') + cluster_id = os.getenv('DATABRICKS_CLUSTER_ID') + + if not all([host, token, cluster_id]): + print("⚠️ Missing Databricks credentials for cluster modification") + return False + + headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'} + + # Get current cluster config + get_url = f"{host}/api/2.0/clusters/get" + response = requests.get(get_url, headers=headers, params={'cluster_id': cluster_id}) + + if response.status_code != 200: + print(f"❌ Failed to get cluster config: {response.text}") + return False + + cluster_config = response.json() + + # Modify spark config to disable Table ACLs + spark_conf = cluster_config.get('spark_conf', {}) + spark_conf['spark.databricks.acl.dfAclsEnabled'] = 'false' + cluster_config['spark_conf'] = spark_conf + + # Update cluster + edit_url = f"{host}/api/2.0/clusters/edit" + response = requests.post(edit_url, headers=headers, json=cluster_config) + + if response.status_code == 200: + print("βœ… Disabled Table ACLs - cluster will restart") + return True + else: + print(f"❌ Failed to disable Table ACLs: {response.text}") + return False + + except Exception as e: + print(f"❌ Error disabling Table ACLs: {e}") + return False + +def check_northwind_exists(): + """Check if Northwind database exists and has data.""" + try: + catalog = os.getenv('DATABRICKS_CATALOG', 'workspace') + schema = os.getenv('DATABRICKS_SCHEMA', 'default') + + # Check if customers table exists and has data + result = execute_query(f"SELECT COUNT(*) as count FROM {catalog}.{schema}.customers") + return result and result[0]['count'] > 0 + except: + return False + +def 
create_northwind_tables(): + """Create Northwind tables with proper Databricks types.""" + catalog = os.getenv('DATABRICKS_CATALOG', '') + schema = os.getenv('DATABRICKS_SCHEMA', 'default') + + # Handle empty catalog + if catalog: + full_schema = f"{catalog}.{schema}" + else: + full_schema = schema + connector = DatabricksRestConnector() + + table_schemas = { + 'categories': "categoryid BIGINT, categoryname STRING, description STRING", + 'customers': "customerid STRING, companyname STRING, contactname STRING, contacttitle STRING, address STRING, city STRING, region STRING, postalcode STRING, country STRING, phone STRING, fax STRING", + 'employees': "employeeid BIGINT, lastname STRING, firstname STRING, title STRING, titleofcourtesy STRING, birthdate STRING, hiredate STRING, address STRING, city STRING, region STRING, postalcode STRING, country STRING, homephone STRING, extension STRING, notes STRING, reportsto STRING", + 'products': "productid BIGINT, productname STRING, supplierid BIGINT, categoryid BIGINT, quantityperunit STRING, unitprice DOUBLE, unitsinstock BIGINT, unitsonorder BIGINT, reorderlevel BIGINT, discontinued BIGINT", + 'suppliers': "supplierid BIGINT, companyname STRING, contactname STRING, contacttitle STRING, address STRING, city STRING, region STRING, postalcode STRING, country STRING, phone STRING, fax STRING, homepage STRING", + 'shippers': "shipperid BIGINT, companyname STRING, phone STRING", + 'orders': "orderid BIGINT, customerid STRING, employeeid BIGINT, orderdate STRING, requireddate STRING, shippeddate STRING, shipvia BIGINT, freight DOUBLE, shipname STRING, shipaddress STRING, shipcity STRING, shipregion STRING, shippostalcode STRING, shipcountry STRING", + 'order_details': "orderid BIGINT, productid BIGINT, unitprice DOUBLE, quantity BIGINT, discount DOUBLE" + } + + try: + if catalog: + connector.execute_query(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}") + print(f"βœ… Created schema {catalog}.{schema}") + table_prefix = f"{catalog}.{schema}" + else: + connector.execute_query(f"CREATE SCHEMA IF NOT EXISTS {schema}") + print(f"βœ… Created schema {schema}") + table_prefix = schema + + for table, columns in table_schemas.items(): + connector.execute_query(f"CREATE TABLE IF NOT EXISTS {table_prefix}.{table} ({columns}) USING DELTA") + + print("βœ… Created Northwind tables") + return True + + except Exception as e: + print(f"❌ Error creating tables: {e}") + return False + +def load_data_to_databricks(data_path): + """Load CSV data with proper type handling.""" + catalog = os.getenv('DATABRICKS_CATALOG', '') + schema = os.getenv('DATABRICKS_SCHEMA', 'default') + + # Handle empty catalog + if catalog: + table_prefix = f"{catalog}.{schema}" + else: + table_prefix = schema + connector = DatabricksRestConnector() + + tables = ['categories', 'customers', 'employees', 'products', 'suppliers', 'shippers', 'orders', 'order_details'] + + try: + for table in tables: + csv_file = os.path.join(data_path, f"{table}.csv") + if not os.path.exists(csv_file): + print(f"⚠️ Skipping {table} - file not found") + continue + + try: + df = pd.read_csv(csv_file) + + # Hard limit columns to prevent schema mismatch + column_limits = { + 'categories': 3, 'customers': 11, 'employees': 16, + 'products': 10, 'suppliers': 12, 'shippers': 3, + 'orders': 14, 'order_details': 5 + } + + if table in column_limits: + original_cols = len(df.columns) + df = df.iloc[:, :column_limits[table]] + print(f" Limited {table} to {column_limits[table]} columns (was {original_cols})") + + # Use bulk insert 
with all data at once + values = [] + for _, row in df.iterrows(): + row_vals = [] + for val in row: + if pd.isna(val): + row_vals.append("NULL") + elif isinstance(val, (int, float)): + row_vals.append(str(val)) + else: + escaped = str(val).replace("'", "''") + row_vals.append(f"'{escaped}'") + values.append(f"({', '.join(row_vals)})") + + if values: + try: + # Single bulk insert + insert_sql = f"INSERT INTO {table_prefix}.{table} VALUES {', '.join(values)}" + connector.execute_query(insert_sql) + except Exception as e: + print(f" ⚠️ Failed to load {table} - continuing...") + continue + + print(f"βœ… Loaded {len(df)} rows into {table}") + except Exception as e: + print(f"❌ Failed to load {table}") + continue + + return True + + except Exception as e: + print(f"❌ Error loading data: {e}") + return False + +def drop_existing_schema(): + """Drop existing northwind schema and all tables.""" + from .databricks_rest_connector import DatabricksRestConnector + + try: + connector = DatabricksRestConnector() + catalog = os.getenv('DATABRICKS_CATALOG', 'workspace') + schema = os.getenv('DATABRICKS_SCHEMA', 'default') + + # Drop schema with CASCADE to remove all tables + connector.execute_query(f"DROP SCHEMA IF EXISTS {catalog}.{schema} CASCADE") + print(f"βœ… Dropped schema {catalog}.{schema}") + + except Exception as e: + print(f"⚠️ Error dropping schema: {e}") + +def bootstrap_northwind(show_progress=False, fresh_start=True, progress_callback=None): + """Bootstrap the complete Northwind database.""" + try: + # Skip Table ACL disabling for serverless clusters + if progress_callback: + progress_callback(0.05, "Preparing serverless environment...") + if show_progress: + print("πŸ”§ Preparing serverless environment...") + + if fresh_start: + if progress_callback: + progress_callback(0.1, "Dropping existing schema...") + if show_progress: + print("πŸ—‘οΈ Dropping existing schema for fresh start...") + drop_existing_schema() + + if progress_callback: + progress_callback(0.2, "Downloading dataset...") + if show_progress: + print("πŸ”„ Downloading Northwind dataset...") + + # Download data + data_path = download_northwind_from_github() + if not data_path: + print("❌ Failed to download Northwind data") + return False + + if progress_callback: + progress_callback(0.5, "Creating database tables...") + if show_progress: + print("πŸ”„ Creating database tables...") + + # Create tables + if not create_northwind_tables(): + return False + + if progress_callback: + progress_callback(0.7, "Loading data into tables...") + if show_progress: + print("πŸ”„ Loading data into tables...") + + # Load data + if not load_data_to_databricks(data_path): + return False + + if progress_callback: + progress_callback(1.0, "Complete!") + if show_progress: + print("βœ… Northwind database bootstrap complete!") + + return True + + except Exception as e: + print(f"❌ Bootstrap failed: {e}") + return False \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/vector_store/__init__.py b/Sales-Analyst-Bedrock-Databricks/src/vector_store/__init__.py new file mode 100644 index 00000000..e3d8eb87 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/vector_store/__init__.py @@ -0,0 +1 @@ +# Vector store package \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Databricks/src/vector_store/faiss_manager.py b/Sales-Analyst-Bedrock-Databricks/src/vector_store/faiss_manager.py new file mode 100644 index 00000000..48bd50c3 --- /dev/null +++ b/Sales-Analyst-Bedrock-Databricks/src/vector_store/faiss_manager.py 
@@ -0,0 +1,137 @@ +""" +FAISS vector store manager for the GenAI Sales Analyst application. +""" +import faiss +import numpy as np +import pickle +import boto3 +from datetime import datetime +from typing import List, Dict, Any, Optional + + +class FAISSManager: + """ + Manages FAISS vector store operations. + """ + + def __init__(self, bedrock_client, s3_bucket: Optional[str] = None, dimension: int = 1536): + """ + Initialize the FAISS manager. + + Args: + bedrock_client: Client for Amazon Bedrock API + s3_bucket: S3 bucket name for storing indices + dimension: Dimension of the embedding vectors + """ + self.bedrock_client = bedrock_client + self.s3_bucket = s3_bucket + self.index = faiss.IndexFlatL2(dimension) + self.texts = [] + self.metadata = [] + + def add_texts(self, texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None): + """ + Add texts and their embeddings to the vector store. + + Args: + texts: List of text strings to add + metadatas: Optional list of metadata dictionaries + """ + if metadatas is None: + metadatas = [{} for _ in texts] + + embeddings = [] + for text in texts: + embedding = self.bedrock_client.get_embeddings(text) + embeddings.append(embedding) + + embeddings_array = np.array(embeddings).astype('float32') + self.index.add(embeddings_array) + self.texts.extend(texts) + self.metadata.extend(metadatas) + + def similarity_search(self, query: str, k: int = 4) -> List[Dict[str, Any]]: + """ + Search for similar texts based on the query. + + Args: + query: Query text + k: Number of results to return + + Returns: + List of dictionaries containing text, metadata, and distance + """ + # Handle empty index + if len(self.texts) == 0: + return [] + + try: + # Get query embedding + query_embedding = self.bedrock_client.get_embeddings(query) + query_array = np.array([query_embedding]).astype('float32') + + # Limit k to the number of items in the index + k = min(k, len(self.texts)) + if k == 0: + return [] + + # Search + distances, indices = self.index.search(query_array, k) + + results = [] + for i, idx in enumerate(indices[0]): + if idx < len(self.texts) and idx >= 0: + results.append({ + 'text': self.texts[idx], + 'metadata': self.metadata[idx], + 'distance': float(distances[0][i]) + }) + return results + except Exception as e: + print(f"Error in similarity search: {str(e)}") + return [] + + def save_index(self) -> str: + """ + Save the index and data to S3. + + Returns: + Message indicating where the index was saved + """ + if not self.s3_bucket: + return "No S3 bucket specified, index not saved" + + s3 = boto3.client('s3') + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + + index_bytes = faiss.serialize_index(self.index) + index_key = f'vector_store/index_{timestamp}.faiss' + s3.put_object(Bucket=self.s3_bucket, Key=index_key, Body=index_bytes) + + data = {'texts': self.texts, 'metadata': self.metadata} + data_key = f'vector_store/data_{timestamp}.pkl' + s3.put_object(Bucket=self.s3_bucket, Key=data_key, Body=pickle.dumps(data)) + + return f"Index saved: {index_key}, Data saved: {data_key}" + + def load_index(self, index_key: str, data_key: str): + """ + Load the index and data from S3. 
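+        Replaces the in-memory index, texts, and metadata with the deserialized copies from S3.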
+ + Args: + index_key: S3 key for the index file + data_key: S3 key for the data file + """ + if not self.s3_bucket: + raise ValueError("No S3 bucket specified") + + s3 = boto3.client('s3') + + index_response = s3.get_object(Bucket=self.s3_bucket, Key=index_key) + index_bytes = index_response['Body'].read() + self.index = faiss.deserialize_index(index_bytes) + + data_response = s3.get_object(Bucket=self.s3_bucket, Key=data_key) + data = pickle.loads(data_response['Body'].read()) + self.texts = data['texts'] + self.metadata = data['metadata'] \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/Apache-2.0 license b/Sales-Analyst-Bedrock-Snowflake/Apache-2.0 license new file mode 100644 index 00000000..dda720f3 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/Apache-2.0 license @@ -0,0 +1,204 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (which shall not include communications that are solely written + by You). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based upon (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and derivative works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control + systems, and issue tracking systems that are managed by, or on behalf + of, the Licensor for the purpose of discussing and improving the Work, + but excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to use, reproduce, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Work, and to + permit persons to whom the Work is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Work. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, trademark, patent, + attribution and other notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright notice to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Support. When redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional support. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in comments for the + particular file format. We also recommend that a file or class name + and description of purpose be included on the same page as the + copyright notice for easier identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/README.md b/Sales-Analyst-Bedrock-Snowflake/README.md new file mode 100644 index 00000000..2bffe433 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/README.md @@ -0,0 +1,224 @@ +# Amazon Bedrock & Snowflake Sales Analyst POC (Text to SQL) + +## Overview of Solution + +This is sample code demonstrating the use of Amazon Bedrock and Generative AI to create an intelligent sales data analyst that uses natural language questions to query relational data stores, specifically Snowflake. This example leverages the complete Northwind sample database with realistic sales scenarios containing customers, orders, and order details. + +![Sales Analyst Demo](images/demo.gif) + +## Goal of this POC +The goal of this repo is to provide users the ability to use Amazon Bedrock and generative AI to ask natural language questions about sales performance, customer behavior, and business metrics. 
These questions are automatically transformed into optimized SQL queries against a Snowflake database. This repo includes intelligent context retrieval using FAISS vector store, LangGraph workflow orchestration, and comprehensive monitoring capabilities. + +The architecture & flow of the POC is as follows: +![POC Architecture & Flow](images/architecture.png 'POC Architecture') + +When a user interacts with the POC, the flow is as follows: + +1. **Natural Language Query**: The user makes a request through the Streamlit interface, asking a natural language question about sales data in Snowflake (`app.py`) + +2. **Query Understanding**: The natural language question is passed to Amazon Bedrock for intent analysis and query classification (`src/graph/workflow.py`) + +3. **Context Retrieval**: The system performs semantic search using FAISS vector store to retrieve relevant database schema information and table relationships (`src/vector_store/faiss_manager.py`) + +4. **Intelligent SQL Generation**: Amazon Bedrock generates optimized SQL queries using the retrieved context, ensuring proper table joins and data type handling (`src/graph/workflow.py`) + +5. **Secure Query Execution**: The SQL query is executed against the Snowflake database through secure connection (`src/utils/snowflake_connector.py`) + +6. **Result Analysis**: The retrieved data is passed back to Amazon Bedrock for intelligent analysis and insight generation (`src/graph/workflow.py`) + +7. **Natural Language Response**: The system returns comprehensive insights and explanations to the user through the Streamlit frontend (`app.py`) + +# How to use this Repo: + +## Prerequisites: + +1. [AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html) installed and configured with access to Amazon Bedrock. + +2. [Python](https://www.python.org/downloads/) v3.11 or greater. The POC runs on Python. + +3. Snowflake account with appropriate permissions to create databases and tables. + +4. AWS account with permissions to access Amazon Bedrock services. + +## Steps + +1. Install Git (Optional step): + ```bash + # Amazon Linux / CentOS / RHEL: + sudo yum install -y git + # Ubuntu / Debian: + sudo apt-get install -y git + # Mac/Windows: Git is usually pre-installed + ``` + +2. Clone the repository to your local machine. + + ```bash + git clone https://github.com/AWS-Samples-GenAI-FSI/Sales-Analyst-Bedrock-Snowflake.git + + ``` + + The file structure of this POC is organized as follows: + + * `requirements.txt` - All dependencies needed for the application + * `app.py` - Main Streamlit application with UI components + * `setup.py` - Setup script for dependencies + * `src/bedrock/bedrock_helper.py` - Amazon Bedrock client wrapper + * `src/graph/workflow.py` - LangGraph workflow orchestration + * `src/vector_store/faiss_manager.py` - FAISS vector store for semantic search + * `src/utils/snowflake_connector.py` - Snowflake database connection management + * `src/utils/helpers.py` - Utility functions + * `src/utils/setup_utils.py` - Setup utilities + * `src/prompts/prompt_template.py` - Prompt management + * `src/prompts/prompts.yaml` - Structured prompts + * `src/monitoring/langfuse_monitor.py` - LangFuse monitoring integration + +3. Open the repository in your favorite code editor. In the terminal, navigate to the POC's folder: + ```bash + cd Sales-Analyst-Bedrock-Snowflake + ``` + +4. 
Configure the Python virtual environment, activate it: + ```bash + python -m venv .venv + source .venv/bin/activate # On Windows: .venv\Scripts\activate + ``` + +5. Install project dependencies: + ```bash + pip install -r requirements.txt + ``` + +6. Configure your credentials by editing the `.env` file and replacing the dummy values with your actual credentials: + + ```bash + # AWS Configuration (Required) + AWS_REGION=us-east-1 + AWS_ACCESS_KEY_ID=your_access_key_here + AWS_SECRET_ACCESS_KEY=your_secret_key_here + + # Snowflake Configuration (Required) + SNOWFLAKE_USER=your_username + SNOWFLAKE_PASSWORD=your_password + SNOWFLAKE_ACCOUNT=your_account_identifier + SNOWFLAKE_WAREHOUSE=your_warehouse + SNOWFLAKE_ROLE=your_role + ``` + +7. Start the application from your terminal: + ```bash + streamlit run app.py + ``` + +8. **Automatic Setup**: On first run, the application will automatically: + - Connect to your Snowflake account + - Check if Northwind sample database exists + - Download Northwind sample data if needed + - Create `SALES_ANALYST` database and `NORTHWIND` schema + - Load complete sample dataset (8 tables with sales data) + - Build vector store with schema metadata + - This process takes approximately 2-3 minutes + +9. **Start Analyzing**: Once setup is complete, you can ask natural language questions like: + - "What are the top 5 customers by order value?" + - "Show me the schema of the CUSTOMERS table" + - "Count the number of orders by country" + - "What's the distribution of order priorities?" + - "What's the average order value by customer?" + - "Which products are most popular?" + +## Architecture Highlights + +- **Zero Configuration**: Automatic database setup and sample data loading +- **Context-Aware AI**: Semantic search for intelligent SQL generation using FAISS +- **Multi-Step AI Pipeline**: Query understanding β†’ Context retrieval β†’ SQL generation β†’ Analysis +- **Workflow Orchestration**: LangGraph-powered structured analysis workflow +- **Performance Monitoring**: LangFuse integration for AI interaction tracking +- **Extensible Design**: Modular architecture for easy customization + +### Built with: + +- Amazon Bedrock: AI/ML models for natural language processing +- Snowflake: Cloud data warehouse for fast analytics +- FAISS: Vector database for semantic search +- Streamlit: Web interface +- LangGraph: Workflow orchestration +- LangFuse: AI monitoring and observability + +### Database Structure +After setup, you'll have access to: +- **CUSTOMERS** - Customer information +- **ORDERS** - Order headers +- **ORDER_DETAILS** - Order line items +- **PRODUCTS** - Product catalog +- **CATEGORIES** - Product categories +- **SUPPLIERS** - Supplier information +- **EMPLOYEES** - Employee data +- **SHIPPERS** - Shipping companies + +## AI-Powered Workflow +The application uses **LangGraph** and **Amazon Bedrock** to create an intelligent analysis workflow: + +1. 🧠 **Understand Query**: AI analyzes your natural language question +2. πŸ” **Retrieve Context**: Finds relevant table/column metadata using FAISS vector search +3. πŸ’» **Generate SQL**: Creates optimized SQL query using context +4. ⚑ **Execute Query**: Runs SQL against your Snowflake database +5. 
πŸ“Š **Analyze Results**: Provides business insights and explanations + +### Key Features +- **Natural Language to SQL**: No SQL knowledge required +- **Intelligent Context**: Understands your database schema automatically +- **Error Recovery**: Handles and recovers from query errors +- **Performance Monitoring**: Tracks AI interactions with LangFuse +- **Persistent Caching**: Speeds up repeated queries + +## Monitoring (Optional) + +**LangFuse Integration** provides: +- πŸ“Š AI interaction tracking +- πŸ”„ Workflow step monitoring +- 🚨 Error logging and analysis +- ⚑ Performance metrics + +To enable, update your credentials in the connector file or set environment variables. + +## Troubleshooting +### Common Issues +- **"Connection failed" errors**: + - Verify your Snowflake credentials are correct + - Check your account identifier format + - Ensure your user has appropriate permissions + - If still connecting to old account, clear cached environment variables: + ```bash + unset SNOWFLAKE_ACCOUNT SNOWFLAKE_USER SNOWFLAKE_PASSWORD + ``` + Then restart the app + +- **"Setup fails" or timeouts**: + - Check your Snowflake warehouse is running + - Verify network connectivity to Snowflake + - Ensure sufficient compute resources + +- **"Credentials not found"**: + - Make sure you updated the `.env` file with your actual credentials + - Make sure `.env` file is in the same directory as `app.py` + - Verify no extra spaces in your credential values + - Check that you saved the `.env` file after editing + +- **"App won't start"**: + - Ensure Python 3.11+ is installed: `python --version` + - Install requirements: `pip install -r requirements.txt` + - Try: `python -m streamlit run app.py` + +- **"AWS Bedrock access denied"**: + - Verify your AWS credentials are configured + - Check your IAM permissions for Bedrock access + - Ensure you're in a supported AWS region + +### Getting Help +- Check Snowflake query history for detailed error messages +- Review AWS CloudWatch logs for Bedrock API calls +- Ensure your Snowflake account has no usage limits blocking queries + +## How-To Guide +For detailed usage instructions and advanced configuration, visit the application's help section within the Streamlit interface. \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/app.py b/Sales-Analyst-Bedrock-Snowflake/app.py new file mode 100644 index 00000000..43b88254 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/app.py @@ -0,0 +1,480 @@ +""" +GenAI Sales Analyst - Main application file. +""" +import streamlit as st +import pandas as pd +import time +import os +import pickle +import numpy as np +from datetime import datetime +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Import components +from src.bedrock.bedrock_helper import BedrockHelper +from src.vector_store.faiss_manager import FAISSManager +from src.monitoring.langfuse_monitor import LangfuseMonitor +from src.graph.workflow import AnalysisWorkflow +from src.utils.snowflake_connector import ( + get_snowflake_connection, + execute_query, + get_available_databases, + get_available_schemas, + get_available_tables, + get_table_columns +) +from src.utils.northwind_bootstrapper import bootstrap_northwind, check_northwind_exists + +def initialize_components(): + """ + Initialize application components. 
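+
+    Sets up the Bedrock helper, the FAISS vector store, the optional LangFuse
+    monitor (enabled when LangFuse keys are configured), and the analysis workflow.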
+ + Returns: + Dictionary of initialized components + """ + # Get environment variables + aws_region = os.getenv('AWS_REGION', 'us-east-1') + s3_bucket = os.getenv('S3_BUCKET', 'your-bucket-name') + + # Initialize Bedrock client + bedrock = BedrockHelper(region_name=aws_region) + + # Initialize vector store + vector_store = FAISSManager( + bedrock_client=bedrock, + s3_bucket=s3_bucket + ) + + # Initialize monitoring + langfuse_public_key = os.getenv('LANGFUSE_PUBLIC_KEY', '') + langfuse_secret_key = os.getenv('LANGFUSE_SECRET_KEY', '') + langfuse_host = os.getenv('LANGFUSE_HOST', '') + + monitor = None + if langfuse_public_key and langfuse_secret_key: + try: + monitor = LangfuseMonitor( + public_key=langfuse_public_key, + secret_key=langfuse_secret_key, + host=langfuse_host if langfuse_host else None + ) + except Exception as e: + st.sidebar.error(f"Error initializing LangFuse: {str(e)}") + + # Initialize workflow + workflow = AnalysisWorkflow( + bedrock_helper=bedrock, + vector_store=vector_store, + monitor=monitor + ) + + return { + 'bedrock': bedrock, + 'vector_store': vector_store, + 'monitor': monitor, + 'workflow': workflow + } + + +def load_all_metadata(vector_store, show_progress=False): + """ + Load metadata from tables in Snowflake sample database. + + Args: + vector_store: Vector store to add metadata to + show_progress: Whether to show progress messages + + Returns: + DataFrame with metadata + """ + # Check if metadata cache exists + cache_file = "metadata_cache.pkl" + if os.path.exists(cache_file): + try: + with open(cache_file, "rb") as f: + cached_data = pickle.load(f) + vector_store.texts = cached_data.get("texts", []) + vector_store.metadata = cached_data.get("metadata", []) + embeddings_array = np.array(cached_data.get("embeddings", [])).astype('float32') + if len(embeddings_array) > 0: + vector_store.index.add(embeddings_array) + if show_progress: + st.sidebar.success(f"βœ… Loaded metadata from cache ({len(vector_store.texts)} items)") + return cached_data.get("dataframe") + except Exception as e: + if show_progress: + st.sidebar.error(f"Error loading cache: {str(e)}") + + # Target tables from Northwind database + database = "SALES_ANALYST" + schema = "NORTHWIND" + tables = [ + "CUSTOMERS", "ORDERS", "ORDER_DETAILS", "PRODUCTS", + "CATEGORIES", "SUPPLIERS", "EMPLOYEES", "SHIPPERS" + ] + + all_metadata = [] + progress_text = "Loading metadata..." 
if show_progress else None + + # Create progress bar if showing progress + if show_progress: + total_tables = len(tables) + progress_bar = st.sidebar.progress(0) + table_count = 0 + + # Load metadata for each table + for table in tables: + try: + if show_progress: + table_count += 1 + progress_bar.progress(table_count / total_tables, text=f"Loading {table}...") + + metadata_df = get_table_columns(database, schema, table) + + # Add table description + table_desc = f"Table {table} in {schema}" + metadata_df['table_description'] = table_desc + metadata_df['database'] = database + metadata_df['schema'] = schema + + all_metadata.append(metadata_df) + + # Also get sample data to enrich metadata + try: + sample_data = execute_query(f"SELECT * FROM {database}.{schema}.{table} LIMIT 5") + if sample_data: + sample_values = {} + for col in metadata_df['column_name'].tolist(): + if col in sample_data[0]: + values = [str(row[col]) for row in sample_data if row[col] is not None][:3] + if values: + sample_values[col] = ", ".join(values) + + # Add sample values to metadata + for i, row in metadata_df.iterrows(): + col = row['column_name'] + if col in sample_values: + metadata_df.at[i, 'sample_values'] = sample_values[col] + except Exception: + # Silently continue if sample data fails + pass + + except Exception as e: + if show_progress: + st.sidebar.error(f"Error loading metadata for {table}: {str(e)}") + + # Clear progress bar if showing progress + if show_progress: + progress_bar.empty() + + # Combine all metadata + if all_metadata: + combined_metadata = pd.concat(all_metadata) + + # Add to vector store + texts = [] + metadatas = [] + + for _, row in combined_metadata.iterrows(): + # Create rich text description + sample_values = f", Sample values: {row.get('sample_values', 'N/A')}" if 'sample_values' in row else "" + text = f"Table: {row['database']}.{row['schema']}.{row['table_name']}, Column: {row['column_name']}, Type: {row['data_type']}, Description: {row['description']}{sample_values}" + texts.append(text) + metadatas.append(row.to_dict()) + + # Get embeddings and add to vector store + embeddings = [] + for text in texts: + embedding = vector_store.bedrock_client.get_embeddings(text) + embeddings.append(embedding) + + # Convert embeddings to numpy array + embeddings_array = np.array(embeddings).astype('float32') + + # Add to vector store + vector_store.texts = texts + vector_store.metadata = metadatas + vector_store.index.add(embeddings_array) + + # Save to cache + try: + with open(cache_file, "wb") as f: + pickle.dump({ + "texts": texts, + "metadata": metadatas, + "embeddings": embeddings, + "dataframe": combined_metadata + }, f) + except Exception as e: + if show_progress: + st.sidebar.warning(f"Could not save metadata cache: {str(e)}") + + return combined_metadata + + return None + + +def main(): + """ + Main application function. + """ + # Set page config + st.set_page_config( + page_title="Sales Data Analyst", + page_icon="πŸ“Š", + layout="wide" + ) + + # Hide Streamlit branding + hide_streamlit_style = """ + + """ + st.markdown(hide_streamlit_style, unsafe_allow_html=True) + + # Custom CSS for other elements + st.markdown(""" + + """, unsafe_allow_html=True) + + # Header with direct HTML and inline styles + st.markdown('

Sales Data Analyst

', unsafe_allow_html=True) + st.markdown('

(Powered by Amazon Bedrock and Snowflake)

', unsafe_allow_html=True) + st.markdown('
', unsafe_allow_html=True) + + # Initialize components + components = initialize_components() + + # Test Snowflake connection and auto-setup database + try: + conn = get_snowflake_connection() + st.sidebar.success("βœ… Connected to Snowflake") + conn.close() + + # Auto-create Northwind database if it doesn't exist + if not check_northwind_exists(): + st.sidebar.info("πŸ”„ Setting up Northwind database...") + with st.spinner("Creating Northwind database with sample data..."): + success = bootstrap_northwind(show_progress=True) + if success: + st.sidebar.success("βœ… Northwind database created successfully!") + else: + st.sidebar.error("❌ Failed to create Northwind database") + return + else: + st.sidebar.success("βœ… Northwind database ready") + + except Exception as e: + st.sidebar.error(f"❌ Snowflake connection failed: {str(e)}") + return + + # Load metadata on startup if not already loaded + if 'metadata_loaded' not in st.session_state or not st.session_state.metadata_loaded: + with st.spinner("Loading database metadata..."): + # Add small delay to ensure database is ready + import time + time.sleep(2) + metadata_df = load_all_metadata(components['vector_store'], show_progress=True) + if metadata_df is not None and len(metadata_df) > 0: + st.session_state.metadata_df = metadata_df + st.session_state.metadata_loaded = True + st.session_state.metadata_count = len(metadata_df) + st.sidebar.success(f"βœ… Loaded metadata for {len(metadata_df)} columns") + else: + st.sidebar.error("❌ Failed to load metadata - try reloading the page") + st.session_state.metadata_loaded = False + + # Sidebar + with st.sidebar: + st.header("Settings") + + # Monitoring status + if components['monitor'] and components['monitor'].enabled: + st.success("βœ… LangFuse monitoring enabled") + else: + st.warning("⚠️ LangFuse monitoring disabled") + + # Workflow status + if components['workflow']: + st.success("βœ… Analysis workflow enabled") + + # Reload metadata button + if st.button("πŸ”„ Reload Metadata", key="reload_metadata"): + with st.spinner("Reloading database metadata..."): + st.session_state.metadata_loaded = False + metadata_df = load_all_metadata(components['vector_store'], show_progress=True) + if metadata_df is not None and len(metadata_df) > 0: + st.session_state.metadata_df = metadata_df + st.session_state.metadata_loaded = True + st.session_state.metadata_count = len(metadata_df) + st.success(f"βœ… Reloaded metadata for {len(metadata_df)} columns") + st.rerun() + else: + st.error("❌ Failed to reload metadata") + + # Available data section moved to sidebar + st.header("πŸ“‹ Available Data") + st.markdown(""" + **🏒 Business Data:** + - πŸ‘₯ **Customers** - Company details, contacts, locations + - πŸ“¦ **Orders** - Order dates, shipping info, freight costs + - πŸ›οΈ **Order Details** - Products, quantities, prices, discounts + + **🏭 Product Catalog:** + - 🎯 **Products** - Names, prices, stock levels + - πŸ“‚ **Categories** - Product groupings and descriptions + - 🚚 **Suppliers** - Vendor information and contacts + + **πŸ‘¨β€πŸ’Ό Operations:** + - πŸ‘” **Employees** - Staff details and hierarchy + - πŸš› **Shippers** - Delivery companies and contacts + """) + + # Show available databases and schemas + with st.expander("Database Explorer", expanded=False): + if st.button("Show Databases"): + try: + databases = get_available_databases() + st.write("Available databases:") + st.write(", ".join(databases)) + except Exception as e: + st.error(f"Error listing databases: {str(e)}") + + # Main content area - use 
full width for col1 + col1 = st.container() + + with col1: + st.markdown('

Ask questions about your sales data

', unsafe_allow_html=True) + st.markdown('

You can ask about customer orders, product sales, and more.

', unsafe_allow_html=True) + + # Examples + with st.expander("πŸ’‘ Example questions", expanded=False): + st.markdown(""" + **βœ… Try these working questions:** + + 1. **What are the top 10 customers by total order value?** + 2. **Which products generate the most revenue?** + 3. **What's the average order value by country?** + 4. **Which product categories sell the most?** + 5. **What are the top 5 most expensive products?** + 6. **How many orders come from each country?** + 7. **Which countries have the highest average order values?** + 8. **Who are our most frequent customers?** + 9. **Which suppliers provide the most products?** + 10. **Which employees process the most orders?** + """) + + # Question input + question = st.text_input( + "Ask your question:", + placeholder="e.g., What are the top 5 customers by order value?" + ) + + # Process question + if question: + if 'metadata_df' not in st.session_state or not st.session_state.get('metadata_loaded', False): + st.error("Metadata not loaded. Please click 'Reload Metadata' button in the sidebar.") + return + + try: + # Execute workflow + with st.spinner("Processing your question..."): + result = components['workflow'].execute(question, execute_query) + + # Display workflow steps + with st.expander("Workflow Steps", expanded=False): + steps = result.get("steps_completed", []) + for step in steps: + if "error" in step: + st.markdown(f'
{step}
', unsafe_allow_html=True) + else: + st.markdown(f'
{step}
', unsafe_allow_html=True) + + # Display error if any + if "error" in result: + st.error(result.get("friendly_error", result["error"])) + + # Display SQL if generated + if "generated_sql" in result: + with st.expander("Generated SQL", expanded=True): + st.code(result["generated_sql"], language="sql") + + # Display results if available + if "query_results" in result: + st.write(f"Query executed in {result.get('execution_time', 0):.2f} seconds, returned {len(result['query_results'])} rows") + with st.expander("Query Results", expanded=True): + st.dataframe(result["query_results"]) + + # Display analysis + if "analysis" in result: + st.subheader("Analysis") + st.write(result["analysis"]) + + # Save to history + if 'history' not in st.session_state: + st.session_state.history = [] + + st.session_state.history.append({ + 'question': question, + 'sql': result.get('generated_sql', ''), + 'results': result.get('query_results', [])[:10], # Store only first 10 rows + 'analysis': result.get('analysis', ''), + 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S') + }) + + except Exception as e: + st.error(f"Error: {str(e)}") + + # Show history + if 'history' in st.session_state and st.session_state.history: + with st.expander("Query History", expanded=False): + for i, item in enumerate(reversed(st.session_state.history[-5:])): # Show last 5 queries + st.write(f"**{item['timestamp']}**: {item['question']}") + if st.button(f"Show details", key=f"history_{i}"): + st.code(item['sql'], language="sql") + st.dataframe(item['results']) + st.write(item['analysis']) + st.divider() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/images/architecture.png b/Sales-Analyst-Bedrock-Snowflake/images/architecture.png new file mode 100644 index 00000000..220532d8 Binary files /dev/null and b/Sales-Analyst-Bedrock-Snowflake/images/architecture.png differ diff --git a/Sales-Analyst-Bedrock-Snowflake/images/demo.gif b/Sales-Analyst-Bedrock-Snowflake/images/demo.gif new file mode 100644 index 00000000..1a18a070 Binary files /dev/null and b/Sales-Analyst-Bedrock-Snowflake/images/demo.gif differ diff --git a/Sales-Analyst-Bedrock-Snowflake/requirements.txt b/Sales-Analyst-Bedrock-Snowflake/requirements.txt new file mode 100644 index 00000000..2b8939f7 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/requirements.txt @@ -0,0 +1,9 @@ +streamlit>=1.24.0 +pandas>=1.5.3 +numpy>=1.24.3 +boto3>=1.28.0 +snowflake-connector-python>=3.0.0 +faiss-cpu>=1.7.4 +python-dotenv>=1.0.0 +langfuse>=3.0.1 +requests>=2.31.0 \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/setup.py b/Sales-Analyst-Bedrock-Snowflake/setup.py new file mode 100644 index 00000000..21e730c6 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/setup.py @@ -0,0 +1,7 @@ +""" +Setup script for the GenAI Sales Analyst application. 
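+
+Invokes run_setup() from src.utils.setup_utils when executed directly.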
+""" +from src.utils.setup_utils import run_setup + +if __name__ == "__main__": + run_setup() \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/__init__.py b/Sales-Analyst-Bedrock-Snowflake/src/__init__.py new file mode 100644 index 00000000..5ea483bc --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/__init__.py @@ -0,0 +1 @@ +# GenAI Sales Analyst package \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/bedrock/__init__.py b/Sales-Analyst-Bedrock-Snowflake/src/bedrock/__init__.py new file mode 100644 index 00000000..c672f0a3 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/bedrock/__init__.py @@ -0,0 +1,3 @@ +""" +Amazon Bedrock integration for the GenAI Sales Analyst application. +""" \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/bedrock/bedrock_helper.py b/Sales-Analyst-Bedrock-Snowflake/src/bedrock/bedrock_helper.py new file mode 100644 index 00000000..3b9e6493 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/bedrock/bedrock_helper.py @@ -0,0 +1,91 @@ +""" +Amazon Bedrock helper for the GenAI Sales Analyst application. +""" +import boto3 +import json +from typing import List, Dict, Any, Optional + + +class BedrockHelper: + """ + Helper class for Amazon Bedrock operations. + """ + + def __init__(self, region_name: str = 'us-east-1'): + """ + Initialize the Bedrock helper. + + Args: + region_name: AWS region name + """ + self.bedrock_runtime = boto3.client( + service_name='bedrock-runtime', + region_name=region_name + ) + + def invoke_model(self, + prompt: str, + model_id: str = "anthropic.claude-3-sonnet-20240229-v1:0", + max_tokens: int = 4096, + temperature: float = 0.7) -> str: + """ + Invoke a Bedrock model with a prompt. + + Args: + prompt: Input prompt text + model_id: Bedrock model ID + max_tokens: Maximum tokens to generate + temperature: Temperature for generation + + Returns: + Model response text + """ + # Format prompt for Claude models + if "anthropic" in model_id: + formatted_prompt = f"Human: {prompt}\n\nAssistant:" + else: + formatted_prompt = prompt + + body = json.dumps({ + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": max_tokens, + "temperature": temperature, + "messages": [ + { + "role": "user", + "content": prompt + } + ] + }) + + try: + response = self.bedrock_runtime.invoke_model( + modelId=model_id, + body=body + ) + response_body = json.loads(response['body'].read()) + return response_body['content'][0]['text'] + except Exception as e: + print(f"Error invoking Bedrock: {str(e)}") + raise + + def get_embeddings(self, text: str) -> List[float]: + """ + Get embeddings for a text using Bedrock. 
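+
+        Uses the amazon.titan-embed-text-v1 embedding model.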
+ + Args: + text: Input text + + Returns: + List of embedding values + """ + try: + response = self.bedrock_runtime.invoke_model( + modelId="amazon.titan-embed-text-v1", + body=json.dumps({"inputText": text}) + ) + response_body = json.loads(response['body'].read()) + return response_body['embedding'] + except Exception as e: + print(f"Error getting embeddings: {str(e)}") + raise \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/config/__init__.py b/Sales-Analyst-Bedrock-Snowflake/src/config/__init__.py new file mode 100644 index 00000000..be6a4fc5 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/config/__init__.py @@ -0,0 +1 @@ +# Config package \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/config/settings.py b/Sales-Analyst-Bedrock-Snowflake/src/config/settings.py new file mode 100644 index 00000000..6c022033 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/config/settings.py @@ -0,0 +1,35 @@ +""" +Configuration settings for the GenAI Sales Analyst application. +""" +import os +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + +# AWS Bedrock settings +AWS_REGION = os.getenv("AWS_REGION", "us-east-1") +DEFAULT_MODEL_ID = "amazon.nova-pro-v1:0" + +# Snowflake settings +SNOWFLAKE_ACCOUNT = os.getenv("SNOWFLAKE_ACCOUNT", "") +SNOWFLAKE_USER = os.getenv("SNOWFLAKE_USER", "") +SNOWFLAKE_PASSWORD = os.getenv("SNOWFLAKE_PASSWORD", "") +SNOWFLAKE_WAREHOUSE = os.getenv("SNOWFLAKE_WAREHOUSE", "COMPUTE_WH") +SNOWFLAKE_ROLE = os.getenv("SNOWFLAKE_ROLE", "ACCOUNTADMIN") + +# Default database and schema +DEFAULT_DATABASE = "SNOWFLAKE_SAMPLE_DATA" +DEFAULT_SCHEMA = "TPCH_SF1" + +# Cache settings +SCHEMA_CACHE_TTL = 3600 # Cache schema information for 1 hour +SCHEMA_CACHE_SIZE = 100 # Maximum number of schemas to cache + +# UI settings +PAGE_TITLE = "GenAI Sales Analyst – Powered by Amazon Bedrock" +PAGE_LAYOUT = "wide" + +# Assets paths +ASSETS_FOLDER = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "assets") +IMAGES_FOLDER = os.path.join(ASSETS_FOLDER, "images") \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/graph/__init__.py b/Sales-Analyst-Bedrock-Snowflake/src/graph/__init__.py new file mode 100644 index 00000000..ca875de7 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/graph/__init__.py @@ -0,0 +1,3 @@ +""" +LangGraph workflow components for the GenAI Sales Analyst application. +""" \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/graph/edges.py b/Sales-Analyst-Bedrock-Snowflake/src/graph/edges.py new file mode 100644 index 00000000..78a78772 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/graph/edges.py @@ -0,0 +1,58 @@ +""" +Edge definitions for LangGraph workflows. +""" +from typing import Dict, Any, Callable + + +class WorkflowEdges: + """ + Edge definitions for LangGraph workflows. + """ + + @staticmethod + def route_to_sql(state: Dict[str, Any]) -> str: + """ + Conditional edge router based on query type. + + Args: + state: Current workflow state + + Returns: + Name of the next node to execute + """ + return ( + 'generate_sql' + if state['query_analysis']['type'] == 'sql' + else 'analyze_data' + ) + + @staticmethod + def get_conditional_edges() -> Dict[str, Dict[str, str]]: + """ + Get conditional edge definitions. 
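+
+        SQL-type queries are routed to generate_sql; everything else goes
+        straight to analyze_data (see route_to_sql above).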
+ + Returns: + Dictionary of conditional edge definitions + """ + return { + 'retrieve_context': { + 'condition': WorkflowEdges.route_to_sql, + 'edges': { + 'generate_sql': 'analyze_data', + 'analyze_data': 'format_response' + } + } + } + + @staticmethod + def get_direct_edges() -> Dict[str, str]: + """ + Get direct edge definitions. + + Returns: + Dictionary of direct edge definitions + """ + return { + 'understand_query': 'retrieve_context', + 'analyze_data': 'format_response' + } \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/graph/nodes.py b/Sales-Analyst-Bedrock-Snowflake/src/graph/nodes.py new file mode 100644 index 00000000..a5ee59bf --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/graph/nodes.py @@ -0,0 +1,155 @@ +""" +Individual node definitions for LangGraph workflows. +""" +from typing import Dict, Any, List +import json +from datetime import datetime + + +class WorkflowNodes: + """ + Node implementations for LangGraph workflows. + """ + + def __init__(self, bedrock_helper, vector_store, monitor): + """ + Initialize workflow nodes with required components. + + Args: + bedrock_helper: Client for Amazon Bedrock API + vector_store: Vector store for similarity search + monitor: Monitoring client + """ + self.bedrock = bedrock_helper + self.vector_store = vector_store + self.monitor = monitor + + def understand_query(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Understand and classify the user query. + + Args: + state: Current workflow state + + Returns: + Updated workflow state + """ + query = state['query'] + + prompt = f"""Analyze this query and classify it: + Query: {query} + Determine: + 1. Query type (analysis/sql/metadata) + 2. Required data sources + 3. Time frame mentioned + Return as JSON.""" + + response = self.bedrock.invoke_model(prompt) + analysis = json.loads(response) + + return { + **state, + "query_analysis": analysis + } + + def retrieve_context(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Retrieve relevant context from vector store. + + Args: + state: Current workflow state + + Returns: + Updated workflow state with relevant context + """ + query = state['query'] + similar_docs = self.vector_store.similarity_search(query, k=3) + + return { + **state, + "relevant_context": similar_docs + } + + def generate_sql(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Generate SQL if needed based on query analysis. + + Args: + state: Current workflow state + + Returns: + Updated workflow state with generated SQL + """ + if state['query_analysis']['type'] != 'sql': + return state + + prompt = f"""Given: + Query: {state['query']} + Context: {state['relevant_context']} + Generate SQL query.""" + + sql = self.bedrock.invoke_model(prompt) + + return { + **state, + "generated_sql": sql + } + + def analyze_data(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform data analysis based on the query and context. + + Args: + state: Current workflow state + + Returns: + Updated workflow state with analysis results + """ + context = state['relevant_context'] + query = state['query'] + + prompt = f"""Analyze: + Question: {query} + Context: {context} + Provide detailed analysis.""" + + analysis = self.bedrock.invoke_model(prompt) + + return { + **state, + "analysis": analysis + } + + def format_response(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Format the final response based on workflow results. 
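+
+        Includes the generated SQL in the response when present and logs the
+        interaction to the monitoring client.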
+ + Args: + state: Current workflow state + + Returns: + Updated workflow state with formatted response + """ + if 'generated_sql' in state: + response = f"""Analysis includes SQL query: + {state['generated_sql']} + + Analysis: + {state['analysis']}""" + else: + response = state['analysis'] + + # Log to monitoring + self.monitor.log_interaction( + prompt=state['query'], + response=response, + metadata={ + "workflow_type": "analysis", + "query_analysis": state['query_analysis'] + } + ) + + return { + **state, + "final_response": response + } \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/graph/workflow.py b/Sales-Analyst-Bedrock-Snowflake/src/graph/workflow.py new file mode 100644 index 00000000..cff039b9 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/graph/workflow.py @@ -0,0 +1,477 @@ +""" +LangGraph workflow for the GenAI Sales Analyst application. +""" +from typing import Dict, Any, List, Tuple +import json +from datetime import datetime + + +class AnalysisWorkflow: + """ + LangGraph workflow for sales data analysis. + """ + + def __init__(self, bedrock_helper, vector_store, monitor=None): + """ + Initialize the analysis workflow. + + Args: + bedrock_helper: Client for Amazon Bedrock API + vector_store: Vector store for similarity search + monitor: Optional monitoring client + """ + self.bedrock = bedrock_helper + self.vector_store = vector_store + self.monitor = monitor + + def understand_query(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Understand and classify the user query. + + Args: + state: Current workflow state + + Returns: + Updated workflow state + """ + query = state['query'] + + prompt = f"""Analyze this query and classify it: + +Query: {query} + +Determine: +1. Query type (analysis/sql/metadata/comparison) +2. Required data sources or tables +3. Time frame mentioned (if any) +4. Specific metrics requested (if any) + +Return as JSON with these fields. +""" + + try: + response = self.bedrock.invoke_model(prompt) + + # Log to monitoring if available + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_interaction( + prompt=prompt, + response=response, + metadata={ + "step_name": "understand_query", + "query": query + }, + trace_id=state.get('trace_id') + ) + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + + # Parse the response as JSON + try: + analysis = json.loads(response) + except json.JSONDecodeError: + # If not valid JSON, create a simple structure + analysis = { + "type": "analysis", + "data_sources": [], + "time_frame": "not specified", + "metrics": [] + } + + return { + **state, + "query_analysis": analysis, + "steps_completed": state.get("steps_completed", []) + ["understand_query"] + } + except Exception as e: + return { + **state, + "error": f"Error in understand_query: {str(e)}", + "steps_completed": state.get("steps_completed", []) + ["understand_query_error"] + } + + def retrieve_context(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Retrieve relevant context from vector store. 
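+
+        Runs a FAISS similarity search (top 5 matches) over the cached schema
+        metadata and attaches the results to the workflow state.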
+ + Args: + state: Current workflow state + + Returns: + Updated workflow state with relevant context + """ + if "error" in state: + return state + + query = state['query'] + + try: + # Get similar documents from vector store + similar_docs = self.vector_store.similarity_search(query, k=5) + + # Handle empty results + if not similar_docs: + # If no similar documents found, create a direct SQL query + if "schema" in query.lower() and "customer" in query.lower(): + # Special case for schema queries + return { + **state, + "generated_sql": "USE DATABASE SNOWFLAKE_SAMPLE_DATA;\nDESCRIBE TABLE TPCH_SF1.CUSTOMER;", + "skip_context": True, + "steps_completed": state.get("steps_completed", []) + ["retrieve_context", "direct_sql"] + } + else: + # For other queries with no context + return { + **state, + "relevant_context": [], + "steps_completed": state.get("steps_completed", []) + ["retrieve_context", "no_results"] + } + + # Log to monitoring if available + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_interaction( + prompt=query, + response=str(similar_docs), + metadata={ + "step_name": "retrieve_context", + "num_results": len(similar_docs) + }, + trace_id=state.get('trace_id') + ) + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + + return { + **state, + "relevant_context": similar_docs, + "steps_completed": state.get("steps_completed", []) + ["retrieve_context"] + } + except Exception as e: + return { + **state, + "error": f"Error in retrieve_context: {str(e)}", + "steps_completed": state.get("steps_completed", []) + ["retrieve_context_error"] + } + + def generate_sql(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Generate SQL based on the query and context. + + Args: + state: Current workflow state + + Returns: + Updated workflow state with generated SQL + """ + if "error" in state: + return state + + # If we already have direct SQL, skip this step + if "skip_context" in state and state["skip_context"] and "generated_sql" in state: + return state + + query = state['query'] + context = state.get('relevant_context', []) + + # Create context string from relevant documents + context_str = "\n".join([f"- {doc['text']}" for doc in context]) + + # If no context is available, use a more generic prompt + if not context_str: + prompt = f"""Generate a SQL query to answer this question: + +Question: {query} + +First, make sure to use the SNOWFLAKE_SAMPLE_DATA database. +This database has the TPCH_SF1 schema with these tables and their important columns: +- CUSTOMER: C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT +- ORDERS: O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT +- LINEITEM: L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE +- PART: P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT +- PARTSUPP: PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT +- SUPPLIER: S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT +- NATION: N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT +- REGION: R_REGIONKEY, R_NAME, R_COMMENT + +IMPORTANT: +1. Start your query with 'USE DATABASE SNOWFLAKE_SAMPLE_DATA;' +2. Always use fully qualified table names including schema (e.g., TPCH_SF1.CUSTOMER) +3. Use the EXACT column names as listed above, including the prefixes (C_, O_, L_, etc.) 
+4. When using table aliases, reference columns with the correct prefix (e.g., c.C_CUSTKEY, o.O_ORDERDATE) + +Generate ONLY the SQL query without any explanation. +""" + else: + prompt = f"""Generate a SQL query to answer this question: + +Question: {query} + +First, make sure to use the SNOWFLAKE_SAMPLE_DATA database. +Relevant context: +{context_str} + +IMPORTANT: +1. Start your query with 'USE DATABASE SNOWFLAKE_SAMPLE_DATA;' +2. Always use fully qualified table names including schema (e.g., TPCH_SF1.CUSTOMER) +3. Use the EXACT column names from the database, including the prefixes (C_, O_, L_, etc.) +4. When using table aliases, reference columns with the correct prefix (e.g., c.C_CUSTKEY, o.O_ORDERDATE) +5. Common column names in TPCH_SF1 schema: + - CUSTOMER: C_CUSTKEY, C_NAME, C_MKTSEGMENT + - ORDERS: O_ORDERKEY, O_CUSTKEY, O_TOTALPRICE, O_ORDERDATE + - LINEITEM: L_ORDERKEY, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT + +Generate ONLY the SQL query without any explanation. +""" + + try: + sql = self.bedrock.invoke_model(prompt) + + # Log to monitoring if available + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_interaction( + prompt=prompt, + response=sql, + metadata={ + "step_name": "generate_sql", + "query": query + }, + trace_id=state.get('trace_id') + ) + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + + # Ensure SQL starts with USE DATABASE + if "use database" not in sql.lower(): + sql = f"USE DATABASE SNOWFLAKE_SAMPLE_DATA;\n{sql}" + + return { + **state, + "generated_sql": sql.strip(), + "steps_completed": state.get("steps_completed", []) + ["generate_sql"] + } + except Exception as e: + return { + **state, + "error": f"Error in generate_sql: {str(e)}", + "steps_completed": state.get("steps_completed", []) + ["generate_sql_error"] + } + + def analyze_results(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Analyze query results and provide an answer. + + Args: + state: Current workflow state + + Returns: + Updated workflow state with analysis + """ + if "error" in state: + return state + + query = state['query'] + sql = state.get('generated_sql', '') + results = state.get('query_results', []) + + # Convert results to string representation + if not results: + analysis = "No results found for this query." + + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_interaction( + prompt=query, + response=analysis, + metadata={ + "step_name": "analyze_results", + "results_count": 0 + }, + trace_id=state.get('trace_id') + ) + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + + return { + **state, + "analysis": analysis, + "steps_completed": state.get("steps_completed", []) + ["analyze_results"] + } + + results_str = "\n".join([str(row) for row in results[:10]]) + if len(results) > 10: + results_str += f"\n... and {len(results) - 10} more rows" + + prompt = f"""Analyze these query results to answer the user's question: + +Question: {query} + +SQL Query: +{sql} + +Query Results (first 10 rows): +{results_str} + +Provide a clear, concise analysis that directly answers the question. Include key insights from the data. 
+""" + + try: + analysis = self.bedrock.invoke_model(prompt) + + # Log to monitoring if available + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_interaction( + prompt=prompt, + response=analysis, + metadata={ + "step_name": "analyze_results", + "results_count": len(results) + }, + trace_id=state.get('trace_id') + ) + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + + return { + **state, + "analysis": analysis.strip(), + "steps_completed": state.get("steps_completed", []) + ["analyze_results"] + } + except Exception as e: + return { + **state, + "error": f"Error in analyze_results: {str(e)}", + "steps_completed": state.get("steps_completed", []) + ["analyze_results_error"] + } + + def handle_error(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + Handle errors in the workflow. + + Args: + state: Current workflow state + + Returns: + Updated workflow state with error handling + """ + error = state.get('error', 'Unknown error') + + # Log to monitoring if available + if self.monitor and self.monitor.enabled: + try: + self.monitor.log_error( + error_message=error, + metadata={ + "query": state.get('query', ''), + "steps_completed": state.get('steps_completed', []) + } + ) + except Exception as e: + print(f"Error logging error to LangFuse: {str(e)}") + + # Generate a user-friendly error message + prompt = f"""An error occurred while processing this query: + +Query: {state.get('query', '')} + +Error: {error} + +Generate a user-friendly error message explaining what went wrong and suggesting how to fix it. +""" + + try: + friendly_message = self.bedrock.invoke_model(prompt) + except Exception: + friendly_message = f"Sorry, an error occurred: {error}. Please try rephrasing your question." + + return { + **state, + "error_handled": True, + "friendly_error": friendly_message.strip(), + "steps_completed": state.get("steps_completed", []) + ["handle_error"] + } + + def execute(self, query: str, execute_query_func=None) -> Dict[str, Any]: + """ + Execute the analysis workflow. 
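+
+        Steps run in sequence: understand_query, retrieve_context, generate_sql,
+        SQL execution via execute_query_func, then analyze_results; errors are
+        routed to handle_error.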
+ + Args: + query: User query string + execute_query_func: Function to execute SQL queries + + Returns: + Final workflow state + """ + # Create trace ID for monitoring + trace_id = None + if self.monitor and self.monitor.enabled: + try: + from uuid import uuid4 + trace_id = f"workflow-{uuid4()}" + except Exception: + pass + + # Initialize state + state = { + "query": query, + "timestamp": datetime.now().isoformat(), + "trace_id": trace_id, + "steps_completed": [] + } + + # Execute workflow steps manually instead of using LangGraph + state = self.understand_query(state) + + if "error" not in state: + state = self.retrieve_context(state) + + if "error" not in state: + state = self.generate_sql(state) + + # Execute SQL if available and no errors + if "generated_sql" in state and "error" not in state and execute_query_func: + try: + start_time = datetime.now() + results = execute_query_func(state["generated_sql"]) + end_time = datetime.now() + execution_time = (end_time - start_time).total_seconds() + + state["query_results"] = results + state["execution_time"] = execution_time + + # Analyze results + state = self.analyze_results(state) + + except Exception as e: + state["error"] = f"Error executing SQL: {str(e)}" + state = self.handle_error(state) + elif "error" in state: + state = self.handle_error(state) + + # Log complete workflow to monitoring + if self.monitor and self.monitor.enabled: + try: + steps = [] + for step in state.get("steps_completed", []): + step_data = { + "name": step + } + steps.append(step_data) + + self.monitor.log_workflow( + workflow_name="analysis_workflow", + steps=steps, + metadata={ + "query": query, + "execution_time": state.get("execution_time"), + "error": state.get("error") + } + ) + except Exception as e: + print(f"Error logging workflow to LangFuse: {str(e)}") + + return state \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/models/__init__.py b/Sales-Analyst-Bedrock-Snowflake/src/models/__init__.py new file mode 100644 index 00000000..d8cfe8af --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/models/__init__.py @@ -0,0 +1 @@ +# Models package \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/models/sql_generator.py b/Sales-Analyst-Bedrock-Snowflake/src/models/sql_generator.py new file mode 100644 index 00000000..a4c7f9a1 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/models/sql_generator.py @@ -0,0 +1,83 @@ +""" +SQL generation model using Amazon Bedrock. +""" +import streamlit as st +from ..utils.bedrock_client import invoke_bedrock_model +from ..utils.query_processor import get_cached_schema_context, extract_sql_from_response +from ..config.settings import DEFAULT_MODEL_ID + + +class SQLGenerator: + """ + SQL Generator class using Amazon Bedrock. + """ + + def __init__(self, model_id=DEFAULT_MODEL_ID): + """ + Initialize the SQL Generator. + + Args: + model_id (str, optional): The model ID to use. Defaults to DEFAULT_MODEL_ID. + """ + self.model_id = model_id + + def generate_sql(self, nl_query, database, schema): + """ + Generate SQL from natural language query. + + Args: + nl_query (str): Natural language query. + database (str): Database name. + schema (str): Schema name. + + Returns: + list: List of SQL queries. 
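+
+        Falls back to _fallback_sql_query when the model returns nothing usable.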
+ """ + schema_context = get_cached_schema_context(database, schema) + schema_context = schema_context[:5000] if len(schema_context) > 5000 else schema_context + + message_content = ( + f"You are an expert SQL generator for Snowflake.\n" + f"Schema Information:\n{schema_context}\n\n" + f"Query: {nl_query}\n\n" + f"Generate a valid SQL query. Respond with only the query." + ) + + result = invoke_bedrock_model(message_content, self.model_id) + + if result: + sql_queries = extract_sql_from_response(result) + if not sql_queries: + st.warning("No valid query generated. Using fallback query.") + sql_queries = self._fallback_sql_query(nl_query) + return sql_queries + else: + return self._fallback_sql_query(nl_query) + + def _fallback_sql_query(self, nl_query): + """ + Generate fallback SQL query for common cases. + + Args: + nl_query (str): Natural language query. + + Returns: + list: List of SQL queries. + """ + # Simplified fallback logic for common cases + if "list tables" in nl_query.lower(): + return ["SHOW TABLES;"] + elif "sample records" in nl_query.lower(): + return ["SELECT * FROM LIMIT 5;"] + elif "highest number of sales orders" in nl_query.lower(): + return [""" + SELECT N.N_NAME, COUNT(O.O_ORDERKEY) AS ORDER_COUNT + FROM SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.NATION N + JOIN SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.CUSTOMER C ON N.N_NATIONKEY = C.C_NATIONKEY + JOIN SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.ORDERS O ON C.C_CUSTKEY = O.O_CUSTKEY + GROUP BY N.N_NAME + ORDER BY ORDER_COUNT DESC + LIMIT 1; + """] + else: + return [] \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/monitoring/__init__.py b/Sales-Analyst-Bedrock-Snowflake/src/monitoring/__init__.py new file mode 100644 index 00000000..a2bbaea1 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/monitoring/__init__.py @@ -0,0 +1,3 @@ +""" +Monitoring module for the GenAI Sales Analyst application. +""" \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/monitoring/langfuse_monitor.py b/Sales-Analyst-Bedrock-Snowflake/src/monitoring/langfuse_monitor.py new file mode 100644 index 00000000..1ee94778 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/monitoring/langfuse_monitor.py @@ -0,0 +1,171 @@ +""" +LangFuse monitoring integration for the GenAI Sales Analyst application. +""" +import uuid +from typing import Dict, Any, Optional, List +import os + + +class LangfuseMonitor: + """ + Manages LangFuse monitoring integration. + """ + + def __init__(self, public_key: str, secret_key: str, host: str = None): + """ + Initialize the LangFuse monitor. + + Args: + public_key: LangFuse public API key + secret_key: LangFuse secret API key + host: Optional LangFuse host URL + """ + try: + # Import langfuse + from langfuse import Langfuse + + # Initialize Langfuse client + langfuse_args = { + "public_key": public_key, + "secret_key": secret_key + } + + # Add host if provided + if host: + langfuse_args["host"] = host + + self.client = Langfuse(**langfuse_args) + self.enabled = True + print("LangFuse monitoring enabled") + except ImportError: + print("Langfuse package not installed. Monitoring will be disabled.") + self.enabled = False + except Exception as e: + print(f"Error initializing LangFuse: {str(e)}") + self.enabled = False + + def log_interaction(self, + prompt: str, + response: str, + metadata: Dict[str, Any], + trace_id: Optional[str] = None) -> Optional[str]: + """ + Log an interaction to LangFuse. 
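+
+        Records the prompt/response pair as a generation span, flushes the
+        client, and returns the current trace ID.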
+ + Args: + prompt: The prompt sent to the model + response: The response from the model + metadata: Additional metadata about the interaction + trace_id: Optional trace ID to link related interactions (ignored in v3) + + Returns: + Trace ID or None if monitoring is disabled + """ + if not self.enabled: + return None + + try: + # Create a generation using context manager with completion included directly + with self.client.start_as_current_generation( + name=metadata.get("step_name", "model_call"), + model=metadata.get("model_id", "anthropic.claude-3-sonnet"), + prompt=prompt, + completion=response, # Include completion directly + metadata=metadata + ): + pass # No need to call end() since completion is included + + # Ensure data is sent + self.client.flush() + + # Get the current trace ID + current_trace_id = self.client.get_current_trace_id() + return current_trace_id + except Exception as e: + print(f"Error logging to LangFuse: {str(e)}") + return None + + def log_workflow(self, + workflow_name: str, + steps: List[Dict[str, Any]], + metadata: Dict[str, Any]) -> Optional[str]: + """ + Log a complete workflow execution to LangFuse. + + Args: + workflow_name: Name of the workflow + steps: List of workflow steps with their details + metadata: Additional metadata about the workflow + + Returns: + Trace ID or None if monitoring is disabled + """ + if not self.enabled: + return None + + try: + # Create a main workflow span + with self.client.start_as_current_span( + name=workflow_name, + metadata=metadata + ): + # Log each step as a span + for step in steps: + with self.client.start_as_current_span( + name=step.get("name", "unknown_step"), + metadata=step.get("metadata", {}) + ): + # Span is automatically ended when the context manager exits + pass + + # Ensure data is sent + self.client.flush() + + # Get the current trace ID + current_trace_id = self.client.get_current_trace_id() + return current_trace_id + except Exception as e: + print(f"Error logging workflow to LangFuse: {str(e)}") + return None + + def log_error(self, error_message: str, metadata: Dict[str, Any] = None) -> None: + """ + Log an error to LangFuse. + + Args: + error_message: Error message + metadata: Additional metadata about the error + """ + if not self.enabled: + return + + try: + if metadata is None: + metadata = {} + + # Create an error span + with self.client.start_as_current_span( + name="error", + metadata={ + **metadata, + "error_message": error_message, + "level": "error" + } + ): + # Span is automatically ended when the context manager exits + pass + + # Create a score for the error + trace_id = self.client.get_current_trace_id() + if trace_id: + self.client.create_score( + name="error", + value=0, # 0 indicates failure + trace_id=trace_id + ) + + # Ensure data is sent + self.client.flush() + except Exception as e: + print(f"Error logging error to LangFuse: {str(e)}") + return \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/prompts/__init__.py b/Sales-Analyst-Bedrock-Snowflake/src/prompts/__init__.py new file mode 100644 index 00000000..c078f395 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/prompts/__init__.py @@ -0,0 +1,3 @@ +""" +Prompt management module for the GenAI Sales Analyst application. 
+""" \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/prompts/prompt_template.py b/Sales-Analyst-Bedrock-Snowflake/src/prompts/prompt_template.py new file mode 100644 index 00000000..ec67d71a --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/prompts/prompt_template.py @@ -0,0 +1,57 @@ +""" +Prompt template management for the GenAI Sales Analyst application. +""" +import yaml +from typing import Dict, List, Any + + +class PromptTemplate: + """ + Manages prompt templates for the application. + """ + + def __init__(self, prompt_file: str = "src/prompts/prompts.yaml"): + """ + Initialize the prompt template manager. + + Args: + prompt_file: Path to the YAML file containing prompt templates + """ + with open(prompt_file, 'r') as f: + self.prompts = yaml.safe_load(f) + + def get_analysis_prompt(self, question: str, context: List[Dict[str, Any]]) -> str: + """ + Get the analysis prompt with the question and context. + + Args: + question: User's question + context: Relevant context information + + Returns: + Formatted prompt string + """ + base_prompt = self.prompts['analysis'] + context_str = "\n".join([f"- {c['text']}" for c in context]) + return base_prompt.format( + question=question, + context=context_str + ) + + def get_sql_prompt(self, question: str, context: List[Dict[str, Any]]) -> str: + """ + Get the SQL generation prompt with the question and context. + + Args: + question: User's question + context: Relevant context information + + Returns: + Formatted prompt string + """ + base_prompt = self.prompts['sql_generation'] + context_str = "\n".join([f"- {c['text']}" for c in context]) + return base_prompt.format( + question=question, + context=context_str + ) \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/prompts/prompts.yaml b/Sales-Analyst-Bedrock-Snowflake/src/prompts/prompts.yaml new file mode 100644 index 00000000..86077442 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/prompts/prompts.yaml @@ -0,0 +1,10 @@ +analysis: | + Given the following question: {question} + And this context about the data: {context} + Provide a detailed analysis of the data. + +sql_generation: | + Given the following question: {question} + And this context about the database schema: {context} + Generate a SQL query that will answer the question. + The query should be efficient and follow best practices. \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/ui/__init__.py b/Sales-Analyst-Bedrock-Snowflake/src/ui/__init__.py new file mode 100644 index 00000000..f4f10bc4 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/ui/__init__.py @@ -0,0 +1 @@ +# UI package \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/ui/components.py b/Sales-Analyst-Bedrock-Snowflake/src/ui/components.py new file mode 100644 index 00000000..f88cfc92 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/ui/components.py @@ -0,0 +1,246 @@ +""" +UI components for the GenAI Sales Analyst application. +""" +import streamlit as st +import pandas as pd +import altair as alt +from PIL import Image +import os +from ..config.settings import IMAGES_FOLDER + + +def display_header(): + """ + Display the application header with logo and title. + """ + # Load LendingTree logo if available + logo_path = os.path.join(IMAGES_FOLDER, "lendingtree_Logo.png") + if os.path.exists(logo_path): + lendingtree_logo = Image.open(logo_path) + st.image(lendingtree_logo, use_column_width=False, width=300) + + st.markdown('
', unsafe_allow_html=True) + st.markdown( + '

GenAI Sales Analyst (Powered by Amazon Bedrock©)

', + unsafe_allow_html=True, + ) + st.markdown('
', unsafe_allow_html=True) + + +def display_config_tab(get_snowflake_databases_fn, get_snowflake_schemas_fn, get_available_models_fn): + """ + Display the configuration tab. + + Args: + get_snowflake_databases_fn (function): Function to get Snowflake databases. + get_snowflake_schemas_fn (function): Function to get Snowflake schemas. + get_available_models_fn (function): Function to get available models. + """ + st.markdown( + """ +
+

🔧 Configuration

+
+ """, + unsafe_allow_html=True, + ) + + databases = get_snowflake_databases_fn() + selected_db = st.selectbox( + "Select Database:", + databases, + index=databases.index(st.session_state.config["database"]) if databases and st.session_state.config["database"] in databases else 0 + ) if databases else None + + schemas = get_snowflake_schemas_fn(selected_db) if selected_db else [] + selected_schema = st.selectbox( + "Select Schema:", + schemas, + index=schemas.index(st.session_state.config["schema"]) if schemas and st.session_state.config["schema"] in schemas else 0 + ) if schemas else None + + available_models = get_available_models_fn() + selected_model = st.selectbox( + "Select AI Model:", + available_models, + index=available_models.index(st.session_state.config["model"]) if available_models and st.session_state.config["model"] in available_models else 0 + ) + + if st.button("Set Configuration", key="set_config"): + st.session_state.config.update({ + "database": selected_db, + "schema": selected_schema, + "model": selected_model + }) + st.success( + f"Configuration updated: Database = {selected_db}, Schema = {selected_schema}, Model = {selected_model}" + ) + + # Information banner at the bottom + st.markdown( + """ +
+

🤖 About This Tool

+

+ Welcome to your intelligent sales analysis assistant, powered by Amazon Bedrock Nova. + This tool helps you navigate and analyze sales data across multiple + Snowflake databases with ease. +

+

+ Simply use natural language to: +

+

+

+ No SQL knowledge required - just ask your questions naturally, and let the AI handle the complexity. +

+
+ """, + unsafe_allow_html=True, + ) + + +def display_analyst_tab(handle_user_query_fn, execute_multiple_sql_queries_fn): + """ + Display the sales analyst tab. + + Args: + handle_user_query_fn (function): Function to handle user queries. + execute_multiple_sql_queries_fn (function): Function to execute SQL queries. + """ + # Create two columns + col1, col2 = st.columns(2, gap="large") + + with col1: + st.markdown( + """ +
+

πŸ“ Sales Analyzer

+
+ """, + unsafe_allow_html=True, + ) + + nl_query = st.text_area( + "Enter your question:", + height=100, + placeholder="Ask anything about the sales data...", + key="query_input" + ) + + if st.button("Submit", key="submit_query") and nl_query.strip(): + db_name = st.session_state.config.get("database") + schema_name = st.session_state.config.get("schema") + model_id = st.session_state.config.get("model") + + if db_name and schema_name: + sql_queries, summary = handle_user_query_fn(nl_query, db_name, schema_name, model_id) + + if sql_queries: + results = execute_multiple_sql_queries_fn(sql_queries, database=db_name, schema=schema_name) + + if results: + # Store the query and result in session state + st.session_state.queries.append(nl_query) + st.session_state.history.append({"query": nl_query, "results": results}) + + for query, df in results.items(): + if df is not None: + st.markdown(f"**Query:** {query}") + st.dataframe(df) + + # Try to visualize the data if possible + if not df.empty and len(df.columns) >= 2: + try: + # Simple heuristic for visualization + if len(df) <= 20: # Small enough for a chart + numeric_cols = df.select_dtypes(include=['number']).columns + if len(numeric_cols) >= 1: + # Choose first string column for x-axis if available + string_cols = df.select_dtypes(include=['object']).columns + if len(string_cols) >= 1: + x_col = string_cols[0] + y_col = numeric_cols[0] + + chart = alt.Chart(df).mark_bar().encode( + x=x_col, + y=y_col, + tooltip=list(df.columns) + ).properties( + title=f"Visualization of {y_col} by {x_col}" + ) + st.altair_chart(chart, use_container_width=True) + except Exception as e: + pass # Silently fail if visualization isn't possible + else: + st.warning(f"No results for query: {query}") + else: + st.warning("No results returned for the queries.") + elif summary: + st.markdown("### Data Summary") + st.text(summary) + + # Store the query and summary in session state + st.session_state.queries.append(nl_query) + st.session_state.history.append({"query": nl_query, "summary": summary}) + else: + st.warning("No SQL queries or summary generated.") + + with col2: + st.markdown( + """ +
+

🗂 Historical Results

+
+ """, + unsafe_allow_html=True, + ) + + if "queries" in st.session_state and "history" in st.session_state: + if st.session_state.queries: + for entry in st.session_state.history: + query = entry["query"] + + st.markdown( + f""" +
+

Query: {query}

+
+ """, + unsafe_allow_html=True, + ) + + if "results" in entry: + result_set = entry["results"] + if isinstance(result_set, dict): + for sql_query, df in result_set.items(): + st.code(sql_query, language="sql") + if df is not None and not df.empty: + st.dataframe(df) + else: + st.warning(f"No results for query: {sql_query}") + else: + st.warning("Unexpected data format in historical results.") + elif "summary" in entry: + st.text(entry["summary"]) + else: + st.markdown("No historical queries found.") + else: + st.warning("Session state is missing query history.") + + +def display_exit_button(reset_app_fn): + """ + Display the exit button. + + Args: + reset_app_fn (function): Function to reset the application. + """ + st.markdown('
', unsafe_allow_html=True) + if st.button("Exit", key="exit_button"): + reset_app_fn() + st.markdown('
', unsafe_allow_html=True) \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/ui/styles.py b/Sales-Analyst-Bedrock-Snowflake/src/ui/styles.py new file mode 100644 index 00000000..c49e56ef --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/ui/styles.py @@ -0,0 +1,125 @@ +""" +CSS styles for the GenAI Sales Analyst application. +""" +import streamlit as st + + +def apply_custom_styles(): + """ + Apply custom CSS styles to the Streamlit application. + """ + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/utils/__init__.py b/Sales-Analyst-Bedrock-Snowflake/src/utils/__init__.py new file mode 100644 index 00000000..f50fc602 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/utils/__init__.py @@ -0,0 +1,3 @@ +""" +Utility functions for the GenAI Sales Analyst application. +""" \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/utils/bedrock_client.py b/Sales-Analyst-Bedrock-Snowflake/src/utils/bedrock_client.py new file mode 100644 index 00000000..b2247854 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/utils/bedrock_client.py @@ -0,0 +1,117 @@ +""" +Amazon Bedrock client utilities. +""" +import boto3 +import json +import streamlit as st +from ..config.settings import AWS_REGION, DEFAULT_MODEL_ID + + +def initialize_bedrock_clients(): + """ + Initialize Amazon Bedrock clients. + + Returns: + tuple: (bedrock_client, bedrock_runtime_client) + """ + try: + bedrock_client = boto3.client("bedrock", region_name=AWS_REGION) + bedrock_runtime_client = boto3.client("bedrock-runtime", region_name=AWS_REGION) + return bedrock_client, bedrock_runtime_client + except Exception as e: + st.error(f"Error initializing Bedrock clients: {e}") + return None, None + + +def get_available_models(): + """ + Fetch available models from Amazon Bedrock. + + Returns: + list: List of available model IDs. + """ + bedrock_client, _ = initialize_bedrock_clients() + try: + response = bedrock_client.list_foundation_models() + models = [model["modelId"] for model in response["modelSummaries"]] + return models + except Exception as e: + st.error(f"Error fetching models from Bedrock: {e}") + return [DEFAULT_MODEL_ID] + + +def invoke_bedrock_model(prompt, model_id=DEFAULT_MODEL_ID): + """ + Invoke Amazon Bedrock model with a prompt. + + Args: + prompt (str): The prompt to send to the model. + model_id (str, optional): The model ID to use. Defaults to DEFAULT_MODEL_ID. + + Returns: + dict: The model response. + """ + _, bedrock_runtime_client = initialize_bedrock_clients() + + payload = { + "messages": [{"role": "user", "content": [{"text": prompt}]}] + } + + try: + response = bedrock_runtime_client.invoke_model( + body=json.dumps(payload), + modelId=model_id, + contentType="application/json", + accept="application/json", + ) + result = json.loads(response["body"].read().decode("utf-8")) + return result + except Exception as e: + st.error(f"Error invoking Bedrock model: {e}") + return None + + +def suggest_chart_from_bedrock(df, model_id=DEFAULT_MODEL_ID): + """ + Use Amazon Bedrock to suggest a chart type based on the data. + + Args: + df (pandas.DataFrame): The DataFrame to analyze. + model_id (str, optional): The model ID to use. Defaults to DEFAULT_MODEL_ID. + + Returns: + str: The suggested chart type. 
+ """ + if df.empty or len(df.columns) < 2: + return None # Skip if the dataframe is empty or has fewer than 2 columns + + # Prepare a summary of the DataFrame for Bedrock + column_summary = [ + { + "column_name": col, + "data_type": str(df[col].dtype), + "unique_values": df[col].nunique() + } + for col in df.columns + ] + + # Convert column summary into a properly formatted string + column_summary_str = json.dumps(column_summary, indent=2) + + # Prepare the prompt + prompt = ( + f"Based on the following data schema and sample:\n\n" + f"{column_summary_str}\n\n" + "Suggest a suitable chart type (e.g., bar, line, scatter) and columns to use for visualization. " + "Respond with 'none' if the data is not suitable for plotting." + ) + + result = invoke_bedrock_model(prompt, model_id) + + if result: + content = result.get("output", {}).get("message", {}).get("content", []) + if isinstance(content, list) and content: + suggestion = content[0].get("text", "").lower() + return suggestion + + return "none" \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/utils/helpers.py b/Sales-Analyst-Bedrock-Snowflake/src/utils/helpers.py new file mode 100644 index 00000000..4a4b5054 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/utils/helpers.py @@ -0,0 +1,46 @@ +""" +Helper functions for the GenAI Sales Analyst application. +""" +import pandas as pd +from typing import Dict, Any, List +import os +from dotenv import load_dotenv + + +def load_environment(): + """ + Load environment variables from .env file. + """ + load_dotenv() + + return { + 'aws_region': os.getenv('AWS_REGION', 'us-east-1'), + 's3_bucket': os.getenv('S3_BUCKET'), + 'langfuse_public_key': os.getenv('LANGFUSE_PUBLIC_KEY'), + 'langfuse_secret_key': os.getenv('LANGFUSE_SECRET_KEY') + } + + +def process_uploaded_data(df: pd.DataFrame, vector_store): + """ + Process and store uploaded data in vector store. + + Args: + df: DataFrame with metadata + vector_store: Vector store instance + + Returns: + Result message from saving the index + """ + texts = [] + metadatas = [] + + for _, row in df.iterrows(): + text = f"{row['column_name']}: {row['description']}" + texts.append(text) + metadatas.append(row.to_dict()) + + vector_store.add_texts(texts, metadatas) + save_result = vector_store.save_index() + + return save_result \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/utils/northwind_bootstrapper.py b/Sales-Analyst-Bedrock-Snowflake/src/utils/northwind_bootstrapper.py new file mode 100644 index 00000000..12c4f79f --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/utils/northwind_bootstrapper.py @@ -0,0 +1,510 @@ +""" +Northwind database bootstrapper for the GenAI Sales Analyst application. 
+""" +import os +import requests +import pandas as pd +import tempfile +import sqlite3 +import streamlit as st +import traceback +import snowflake.connector +from dotenv import load_dotenv +from .snowflake_connector import get_snowflake_connection + +# Load environment variables (override existing ones) +load_dotenv(override=True) + +DATABASE_NAME = "SALES_ANALYST" +NORTHWIND_SCHEMA = "NORTHWIND" +NORTHWIND_TABLES = ["CUSTOMERS", "PRODUCTS", "ORDERS", "ORDER_DETAILS", "CATEGORIES", "SUPPLIERS", "EMPLOYEES", "SHIPPERS"] +NORTHWIND_DATA_URL = "https://raw.githubusercontent.com/jpwhite3/northwind-SQLite3/master/northwind.db" + +def check_northwind_exists(): + """Check if Northwind schema and tables exist in Snowflake.""" + try: + # Use get_snowflake_connection for consistency + conn = get_snowflake_connection() + + cursor = conn.cursor() + + # Create database if it doesn't exist + cursor.execute(f"CREATE DATABASE IF NOT EXISTS {DATABASE_NAME}") + cursor.execute(f"USE DATABASE {DATABASE_NAME}") + print(f"Using database: {DATABASE_NAME}") + + # Check if schema exists + cursor.execute(f"SHOW SCHEMAS LIKE '{NORTHWIND_SCHEMA}'") + result = cursor.fetchall() + if not result: + print(f"Schema {NORTHWIND_SCHEMA} does not exist") + return False + + # Check if tables exist + for table in NORTHWIND_TABLES: + cursor.execute(f"SHOW TABLES LIKE '{table}' IN SCHEMA {NORTHWIND_SCHEMA}") + result = cursor.fetchall() + if not result: + print(f"Table {NORTHWIND_SCHEMA}.{table} does not exist") + return False + + # Check if data exists (sample count from ORDERS table) + cursor.execute(f"SELECT COUNT(*) AS COUNT FROM {NORTHWIND_SCHEMA}.ORDERS") + result = cursor.fetchone() + if not result or result[0] < 1: + print(f"No data in {NORTHWIND_SCHEMA}.ORDERS") + return False + + return True + except Exception as e: + print(f"Error checking if Northwind exists: {str(e)}") + return False + finally: + if 'conn' in locals(): + conn.close() + +# Rest of the file remains unchanged +def create_northwind_schema(): + """Create Northwind schema in Snowflake.""" + try: + conn = get_snowflake_connection() + cursor = conn.cursor() + + # Create database if it doesn't exist + cursor.execute(f"CREATE DATABASE IF NOT EXISTS {DATABASE_NAME}") + cursor.execute(f"USE DATABASE {DATABASE_NAME}") + print(f"Using database: {DATABASE_NAME}") + + # Now create the schema + cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {NORTHWIND_SCHEMA}") + print(f"Created schema {NORTHWIND_SCHEMA}") + conn.close() + return True + except Exception as e: + print(f"Error creating schema: {str(e)}") + traceback.print_exc() + return False + +def download_northwind_data(): + """Download Northwind SQLite database.""" + try: + temp_dir = tempfile.mkdtemp() + sqlite_path = os.path.join(temp_dir, "northwind.db") + + print(f"Downloading from {NORTHWIND_DATA_URL}") + # Download SQLite database with proper headers + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' + } + response = requests.get(NORTHWIND_DATA_URL, headers=headers, stream=True) + + if response.status_code != 200: + print(f"Failed to download: HTTP {response.status_code}") + # Try alternative URL + alt_url = "https://github.com/jpwhite3/northwind-SQLite3/raw/master/northwind.db" + print(f"Trying alternative URL: {alt_url}") + response = requests.get(alt_url, headers=headers, stream=True) + if response.status_code != 200: + print(f"Failed to download from alternative URL: HTTP {response.status_code}") + # Try another 
alternative URL + alt_url2 = "https://github.com/Microsoft/sql-server-samples/raw/master/samples/databases/northwind-pubs/instnwnd.sql" + print(f"Trying another alternative URL: {alt_url2}") + response = requests.get(alt_url2, headers=headers) + if response.status_code != 200: + print(f"Failed to download from all URLs: HTTP {response.status_code}") + return None + else: + # This is SQL, not SQLite, so we'll create a simple SQLite DB with the main tables + print("Creating sample Northwind data manually") + create_sample_northwind_data(sqlite_path) + return sqlite_path + + with open(sqlite_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + # Verify it's a valid SQLite database + try: + test_conn = sqlite3.connect(sqlite_path) + test_cursor = test_conn.cursor() + test_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1") + test_cursor.fetchone() + test_conn.close() + except sqlite3.DatabaseError: + print("Downloaded file is not a valid SQLite database") + print("Creating sample Northwind data manually") + create_sample_northwind_data(sqlite_path) + + print(f"Downloaded to {sqlite_path}") + return sqlite_path + except Exception as e: + print(f"Error downloading data: {str(e)}") + traceback.print_exc() + return None + +def create_sample_northwind_data(sqlite_path): + """Create a sample Northwind database with basic data.""" + conn = sqlite3.connect(sqlite_path) + cursor = conn.cursor() + + # Create tables + cursor.execute(''' + CREATE TABLE Customers ( + CustomerID TEXT PRIMARY KEY, + CompanyName TEXT, + ContactName TEXT, + ContactTitle TEXT, + Address TEXT, + City TEXT, + Region TEXT, + PostalCode TEXT, + Country TEXT, + Phone TEXT, + Fax TEXT + ) + ''') + + cursor.execute(''' + CREATE TABLE Products ( + ProductID INTEGER PRIMARY KEY, + ProductName TEXT, + SupplierID INTEGER, + CategoryID INTEGER, + QuantityPerUnit TEXT, + UnitPrice REAL, + UnitsInStock INTEGER, + UnitsOnOrder INTEGER, + ReorderLevel INTEGER, + Discontinued INTEGER + ) + ''') + + cursor.execute(''' + CREATE TABLE Orders ( + OrderID INTEGER PRIMARY KEY, + CustomerID TEXT, + EmployeeID INTEGER, + OrderDate TEXT, + RequiredDate TEXT, + ShippedDate TEXT, + ShipVia INTEGER, + Freight REAL, + ShipName TEXT, + ShipAddress TEXT, + ShipCity TEXT, + ShipRegion TEXT, + ShipPostalCode TEXT, + ShipCountry TEXT + ) + ''') + + cursor.execute(''' + CREATE TABLE Order_Details ( + OrderID INTEGER, + ProductID INTEGER, + UnitPrice REAL, + Quantity INTEGER, + Discount REAL, + PRIMARY KEY (OrderID, ProductID) + ) + ''') + + cursor.execute(''' + CREATE TABLE Categories ( + CategoryID INTEGER PRIMARY KEY, + CategoryName TEXT, + Description TEXT + ) + ''') + + cursor.execute(''' + CREATE TABLE Suppliers ( + SupplierID INTEGER PRIMARY KEY, + CompanyName TEXT, + ContactName TEXT, + ContactTitle TEXT, + Address TEXT, + City TEXT, + Region TEXT, + PostalCode TEXT, + Country TEXT, + Phone TEXT, + Fax TEXT, + HomePage TEXT + ) + ''') + + cursor.execute(''' + CREATE TABLE Employees ( + EmployeeID INTEGER PRIMARY KEY, + LastName TEXT, + FirstName TEXT, + Title TEXT, + TitleOfCourtesy TEXT, + BirthDate TEXT, + HireDate TEXT, + Address TEXT, + City TEXT, + Region TEXT, + PostalCode TEXT, + Country TEXT, + HomePhone TEXT, + Extension TEXT, + Notes TEXT, + ReportsTo INTEGER + ) + ''') + + cursor.execute(''' + CREATE TABLE Shippers ( + ShipperID INTEGER PRIMARY KEY, + CompanyName TEXT, + Phone TEXT + ) + ''') + + # Add sample data + # Customers + customers = [ + ('ALFKI', 'Alfreds Futterkiste', 
'Maria Anders', 'Sales Representative', 'Obere Str. 57', 'Berlin', None, '12209', 'Germany', '030-0074321', '030-0076545'), + ('ANATR', 'Ana Trujillo Emparedados y helados', 'Ana Trujillo', 'Owner', 'Avda. de la ConstituciΓ³n 2222', 'MΓ©xico D.F.', None, '05021', 'Mexico', '(5) 555-4729', '(5) 555-3745'), + ('ANTON', 'Antonio Moreno TaquerΓ­a', 'Antonio Moreno', 'Owner', 'Mataderos 2312', 'MΓ©xico D.F.', None, '05023', 'Mexico', '(5) 555-3932', None) + ] + cursor.executemany('INSERT INTO Customers VALUES (?,?,?,?,?,?,?,?,?,?,?)', customers) + + # Products + products = [ + (1, 'Chai', 1, 1, '10 boxes x 20 bags', 18.0, 39, 0, 10, 0), + (2, 'Chang', 1, 1, '24 - 12 oz bottles', 19.0, 17, 40, 25, 0), + (3, 'Aniseed Syrup', 1, 2, '12 - 550 ml bottles', 10.0, 13, 70, 25, 0) + ] + cursor.executemany('INSERT INTO Products VALUES (?,?,?,?,?,?,?,?,?,?)', products) + + # Orders + orders = [ + (10248, 'VINET', 5, '1996-07-04', '1996-08-01', '1996-07-16', 3, 32.38, 'Vins et alcools Chevalier', '59 rue de l Abbaye', 'Reims', None, '51100', 'France'), + (10249, 'TOMSP', 6, '1996-07-05', '1996-08-16', '1996-07-10', 1, 11.61, 'Toms SpezialitΓ€ten', 'Luisenstr. 48', 'MΓΌnster', None, '44087', 'Germany'), + (10250, 'HANAR', 4, '1996-07-08', '1996-08-05', '1996-07-12', 2, 65.83, 'Hanari Carnes', 'Rua do Paco 67', 'Rio de Janeiro', 'RJ', '05454-876', 'Brazil') + ] + cursor.executemany('INSERT INTO Orders VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)', orders) + + # Order Details + order_details = [ + (10248, 11, 14.0, 12, 0), + (10248, 42, 9.8, 10, 0), + (10249, 14, 18.6, 9, 0), + (10249, 51, 42.4, 40, 0), + (10250, 41, 7.7, 10, 0), + (10250, 51, 42.4, 35, 0.15), + (10250, 65, 16.8, 15, 0.15) + ] + cursor.executemany('INSERT INTO Order_Details VALUES (?,?,?,?,?)', order_details) + + # Categories + categories = [ + (1, 'Beverages', 'Soft drinks coffees teas beers and ales'), + (2, 'Condiments', 'Sweet and savory sauces relishes spreads and seasonings'), + (3, 'Confections', 'Desserts candies and sweet breads') + ] + cursor.executemany('INSERT INTO Categories VALUES (?,?,?)', categories) + + # Suppliers + suppliers = [ + (1, 'Exotic Liquids', 'Charlotte Cooper', 'Purchasing Manager', '49 Gilbert St.', 'London', None, 'EC1 4SD', 'UK', '(171) 555-2222', None, None), + (2, 'New Orleans Cajun Delights', 'Shelley Burke', 'Order Administrator', 'P.O. Box 78934', 'New Orleans', 'LA', '70117', 'USA', '(100) 555-4822', None, None), + (3, 'Grandma Kellys Homestead', 'Regina Murphy', 'Sales Representative', '707 Oxford Rd.', 'Ann Arbor', 'MI', '48104', 'USA', '(313) 555-5735', '(313) 555-3349', None) + ] + cursor.executemany('INSERT INTO Suppliers VALUES (?,?,?,?,?,?,?,?,?,?,?,?)', suppliers) + + # Employees + employees = [ + (1, 'Davolio', 'Nancy', 'Sales Representative', 'Ms.', '1968-12-08', '1992-05-01', '507 - 20th Ave. E. Apt. 2A', 'Seattle', 'WA', '98122', 'USA', '(206) 555-9857', '5467', 'Education includes a BA in psychology from Colorado State University.', 2), + (2, 'Fuller', 'Andrew', 'Vice President Sales', 'Dr.', '1952-02-19', '1992-08-14', '908 W. Capital Way', 'Tacoma', 'WA', '98401', 'USA', '(206) 555-9482', '3457', 'Andrew received his BTS commercial and a Ph.D. 
in international marketing from the University of Dallas.', None), + (3, 'Leverling', 'Janet', 'Sales Representative', 'Ms.', '1963-08-30', '1992-04-01', '722 Moss Bay Blvd.', 'Kirkland', 'WA', '98033', 'USA', '(206) 555-3412', '3355', 'Janet has a BS degree in chemistry from Boston College.', 2) + ] + cursor.executemany('INSERT INTO Employees VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)', employees) + + # Shippers + shippers = [ + (1, 'Speedy Express', '(503) 555-9831'), + (2, 'United Package', '(503) 555-3199'), + (3, 'Federal Shipping', '(503) 555-9931') + ] + cursor.executemany('INSERT INTO Shippers VALUES (?,?,?)', shippers) + + conn.commit() + conn.close() + +def extract_data_from_sqlite(sqlite_path): + """Extract data from SQLite database into pandas DataFrames.""" + try: + conn = sqlite3.connect(sqlite_path) + + # Dictionary to hold all tables' data + tables_data = {} + + # Get all tables + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + + print(f"Found tables in SQLite: {tables}") + + # Extract data from each table + for table in tables: + print(f"Extracting data from {table}") + tables_data[table] = pd.read_sql_query(f"SELECT * FROM {table}", conn) + + conn.close() + return tables_data + except Exception as e: + print(f"Error extracting data: {str(e)}") + traceback.print_exc() + return {} + +def get_create_table_ddl(table_name, df): + """Generate CREATE TABLE DDL from pandas DataFrame.""" + try: + # Map pandas dtypes to Snowflake data types + dtype_map = { + 'int64': 'NUMBER', + 'float64': 'FLOAT', + 'object': 'VARCHAR(255)', + 'datetime64[ns]': 'TIMESTAMP_NTZ', + 'bool': 'BOOLEAN' + } + + columns = [] + for col, dtype in df.dtypes.items(): + sf_type = dtype_map.get(str(dtype), 'VARCHAR(255)') + columns.append(f'"{col}" {sf_type}') + + ddl = f"CREATE OR REPLACE TABLE {NORTHWIND_SCHEMA}.{table_name.upper()} (\n" + ddl += ",\n".join(columns) + ddl += "\n)" + + return ddl + except Exception as e: + print(f"Error generating DDL for {table_name}: {str(e)}") + return None + +def load_data_to_snowflake(tables_data): + """Load data into Snowflake tables.""" + conn = get_snowflake_connection() + + try: + cursor = conn.cursor() + + # Use warehouse first + warehouse = os.getenv('SNOWFLAKE_WAREHOUSE', 'COMPUTE_WH') + cursor.execute(f"USE WAREHOUSE {warehouse}") + print(f"Using warehouse: {warehouse}") + + # Create database if it doesn't exist + cursor.execute(f"CREATE DATABASE IF NOT EXISTS {DATABASE_NAME}") + cursor.execute(f"USE DATABASE {DATABASE_NAME}") + print(f"Using database: {DATABASE_NAME}") + + # Create schema if it doesn't exist + cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {NORTHWIND_SCHEMA}") + cursor.execute(f"USE SCHEMA {NORTHWIND_SCHEMA}") + print(f"Using schema: {NORTHWIND_SCHEMA}") + + for table_name, df in tables_data.items(): + # Skip sqlite_sequence table + if table_name.lower() == 'sqlite_sequence': + continue + + print(f"Creating table {table_name.upper()}") + # Create table + create_table_sql = get_create_table_ddl(table_name.upper(), df) + if not create_table_sql: + print(f"Skipping table {table_name} due to DDL generation error") + continue + + cursor.execute(create_table_sql) + + # Load data using direct INSERT statements + print(f"Loading data into {table_name.upper()}") + + # Convert DataFrame to list of tuples, handling NaN properly + data_tuples = [] + for _, row in df.iterrows(): + row_tuple = tuple(None if pd.isna(val) else val for val in row) + 
data_tuples.append(row_tuple) + + # Create INSERT statement + placeholders = ','.join(['%s'] * len(df.columns)) + columns = ','.join([f'"{col}"' for col in df.columns]) + insert_sql = f"INSERT INTO {NORTHWIND_SCHEMA}.{table_name.upper()} ({columns}) VALUES ({placeholders})" + + # Execute batch insert + cursor.executemany(insert_sql, data_tuples) + print(f"Inserted {len(data_tuples)} rows into {table_name.upper()}") + return True + except Exception as e: + print(f"Error loading data to Snowflake: {str(e)}") + traceback.print_exc() + return False + finally: + conn.close() + +def bootstrap_northwind(show_progress=False): + """Bootstrap Northwind database in Snowflake if it doesn't exist.""" + if check_northwind_exists(): + if show_progress: + st.success("βœ… Northwind database already exists in Snowflake.") + return True + + if show_progress: + st.info("πŸ”„ Bootstrapping Northwind database...") + progress_bar = st.progress(0) + + try: + # Create schema + schema_result = create_northwind_schema() + if not schema_result: + if show_progress: + st.error("❌ Failed to create Northwind schema.") + return False + + if show_progress: + progress_bar.progress(0.1, text="Created schema...") + + # Download data + if show_progress: + progress_bar.progress(0.2, text="Downloading data...") + sqlite_path = download_northwind_data() + if not sqlite_path: + if show_progress: + st.error("❌ Failed to download Northwind data.") + return False + + # Extract data + if show_progress: + progress_bar.progress(0.4, text="Extracting data...") + tables_data = extract_data_from_sqlite(sqlite_path) + if not tables_data: + if show_progress: + st.error("❌ Failed to extract data from SQLite.") + return False + + # Load data to Snowflake + if show_progress: + progress_bar.progress(0.6, text="Loading data to Snowflake...") + load_result = load_data_to_snowflake(tables_data) + if not load_result: + if show_progress: + st.error("❌ Failed to load data to Snowflake.") + return False + + if show_progress: + progress_bar.progress(1.0, text="Completed!") + st.success("βœ… Northwind database successfully bootstrapped.") + return True + except Exception as e: + if show_progress: + st.error(f"❌ Error bootstrapping Northwind database: {str(e)}") + print(f"Error bootstrapping Northwind: {str(e)}") + traceback.print_exc() + return False \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/utils/query_processor.py b/Sales-Analyst-Bedrock-Snowflake/src/utils/query_processor.py new file mode 100644 index 00000000..dfdd4969 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/utils/query_processor.py @@ -0,0 +1,349 @@ +""" +Query processing utilities. +""" +import re +import streamlit as st +import pandas as pd +from cachetools import TTLCache +from .snowflake_connector import get_detailed_schema_info +from .bedrock_client import invoke_bedrock_model +from ..config.settings import SCHEMA_CACHE_TTL, SCHEMA_CACHE_SIZE, DEFAULT_MODEL_ID + +# Cache for schema information +schema_cache = TTLCache(maxsize=SCHEMA_CACHE_SIZE, ttl=SCHEMA_CACHE_TTL) + + +def generate_schema_context(database, schema): + """ + Generate a natural language description of the schema. + + Args: + database (str): Database name. + schema (str): Schema name. + + Returns: + str: Natural language description of the schema. 
+ """ + schema_info = get_detailed_schema_info(database, schema) + + context = f"Database '{database}' with schema '{schema}' contains the following structure:\n\n" + + for table, columns in schema_info.items(): + context += f"Table '{table}' contains:\n" + for col_name, col_info in columns.items(): + context += f"- Column '{col_name}' of type {col_info['data_type']}" + + # Add sample values if available + if col_info['sample_values']: + context += f" (example values: {', '.join(col_info['sample_values'])})" + + # Add column comment if available + if col_info['comment']: + context += f" - {col_info['comment']}" + + context += "\n" + context += "\n" + + return context + + +def get_cached_schema_context(database, schema): + """ + Get schema context with caching. + + Args: + database (str): Database name. + schema (str): Schema name. + + Returns: + str: Natural language description of the schema. + """ + cache_key = f"{database}_{schema}" + if cache_key not in schema_cache: + schema_cache[cache_key] = generate_schema_context(database, schema) + return schema_cache[cache_key] + + +def extract_sql_from_response(response_json): + """ + Extract SQL queries from the Bedrock response. + + Args: + response_json (dict): The Bedrock response. + + Returns: + list: List of SQL queries. + """ + try: + # Extract SQL text from response + if isinstance(response_json, dict): + content = response_json.get("output", {}).get("message", {}).get("content", []) + if isinstance(content, list) and content: + sql_text = content[0].get("text", "").strip() + else: + raise ValueError("Empty or invalid Bedrock response content.") + elif isinstance(response_json, list): + sql_text = response_json[0].strip() + else: + raise ValueError("Unexpected response structure.") + + # Remove Markdown formatting + if sql_text.startswith("```sql"): + sql_text = sql_text[6:].strip() + if sql_text.endswith("```"): + sql_text = sql_text[:-3].strip() + + # Fix Snowflake syntax: replace single quotes with double quotes for identifiers + sql_text = re.sub(r"'([^']*)'\\.'([^']*)'", r'"\\1"."\\2"', sql_text) + sql_text = re.sub(r"'([^']*)'", r'"\\1"', sql_text) + + # Handle common NLP queries for table listings + if re.search(r"show\\s+tables", sql_text, re.IGNORECASE): + return ["SHOW TABLES"] + + # Handle common NLP queries for record previews + preview_match = re.search(r"(top|first)\\s+(\\d+)", sql_text, re.IGNORECASE) + if preview_match and "select" in sql_text.lower(): + limit_num = preview_match.group(2) + if "limit" not in sql_text.lower(): + sql_text = f"{sql_text} LIMIT {limit_num}" + + # Fix "SHOW TABLES" query syntax + sql_text = re.sub(r"SHOW TABLES IN DATABASE (\\S+)\\.(\\S+);?", r"SHOW TABLES IN SCHEMA \\1.\\2;", sql_text) + + # Split multiple SQL queries + sql_queries = [query.strip() for query in sql_text.split(";") if query.strip()] + return sql_queries + + except Exception as e: + st.error(f"Error extracting SQL queries: {e}") + return [] + + +def generate_sql_query(nl_query, db_name, schema_name, model_id=DEFAULT_MODEL_ID): + """ + Generate SQL query from natural language query. + + Args: + nl_query (str): Natural language query. + db_name (str): Database name. + schema_name (str): Schema name. + model_id (str, optional): The model ID to use. Defaults to DEFAULT_MODEL_ID. + + Returns: + list: List of SQL queries. 
+ """ + schema_context = get_cached_schema_context(db_name, schema_name) + schema_context = schema_context[:5000] if len(schema_context) > 5000 else schema_context + + message_content = ( + f"You are an expert SQL generator for Snowflake.\n" + f"Schema Information:\n{schema_context}\n\n" + f"Query: {nl_query}\n\n" + f"Generate a valid SQL query. Respond with only the query." + ) + + result = invoke_bedrock_model(message_content, model_id) + + if result: + sql_query = extract_sql_from_response(result) + if not sql_query: + st.warning("No valid query generated. Using fallback query.") + sql_query = fallback_sql_query(nl_query) + return sql_query + else: + return fallback_sql_query(nl_query) + + +def fallback_sql_query(nl_query): + """ + Generate fallback SQL query for common cases. + + Args: + nl_query (str): Natural language query. + + Returns: + list: List of SQL queries. + """ + # Simplified fallback logic for common cases + if "list tables" in nl_query.lower(): + return ["SHOW TABLES;"] + elif "sample records" in nl_query.lower(): + return ["SELECT * FROM LIMIT 5;"] + elif "highest number of sales orders" in nl_query.lower(): + return [""" + SELECT N.N_NAME, COUNT(O.O_ORDERKEY) AS ORDER_COUNT + FROM SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.NATION N + JOIN SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.CUSTOMER C ON N.N_NATIONKEY = C.C_NATIONKEY + JOIN SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.ORDERS O ON C.C_CUSTKEY = O.O_CUSTKEY + GROUP BY N.N_NAME + ORDER BY ORDER_COUNT DESC + LIMIT 1; + """] + else: + return [] + + +def find_relevant_tables(schema_info, topic): + """ + Find tables relevant to a specific topic based on table and column names. + + Args: + schema_info (dict): Schema information. + topic (str): Topic to search for. + + Returns: + list: List of relevant table names. + """ + relevant_tables = [] + topic_keywords = set(topic.lower().split()) + + for table_name, columns in schema_info.items(): + # Check table name + if any(keyword in table_name.lower() for keyword in topic_keywords): + relevant_tables.append(table_name) + continue + + # Check column names and comments + for col_name, col_info in columns.items(): + if any(keyword in col_name.lower() for keyword in topic_keywords): + relevant_tables.append(table_name) + break + if col_info['comment'] and any(keyword in col_info['comment'].lower() for keyword in topic_keywords): + relevant_tables.append(table_name) + break + + return list(set(relevant_tables)) + + +def generate_data_summary(database, schema, topic, limit=1000): + """ + Generate a comprehensive data summary for a specific topic. + + Args: + database (str): Database name. + schema (str): Schema name. + topic (str): Topic to summarize. + limit (int, optional): Maximum number of records to analyze. Defaults to 1000. + + Returns: + str: Data summary. 
+ """ + from .snowflake_connector import get_detailed_schema_info, analyze_table_relationships, connect_to_snowflake + + try: + # Get schema information + schema_info = get_detailed_schema_info(database, schema) + + # Only try to get relationships if not using sample data + if not database.startswith('SNOWFLAKE_SAMPLE'): + relationships = analyze_table_relationships(database, schema) + else: + relationships = {} + + relevant_tables = find_relevant_tables(schema_info, topic) + + if not relevant_tables: + return f"No relevant tables found for topic: {topic}" + + summaries = [] + conn = connect_to_snowflake() + cursor = conn.cursor() + + for table_name in relevant_tables: + columns = schema_info[table_name] + # Identify key metrics columns (numeric types) + metric_columns = [col for col, info in columns.items() + if info['data_type'].upper() in ('NUMBER', 'FLOAT', 'INTEGER', 'DECIMAL')] + + # Identify categorical columns + categorical_columns = [col for col, info in columns.items() + if info['data_type'].upper() in ('VARCHAR', 'STRING', 'CHAR')] + + # Generate summary query + summary_query = f""" + SELECT + COUNT(*) as total_records + {', ' + ', '.join(f'COUNT(DISTINCT {col}) as unique_{col}' for col in categorical_columns[:5]) if categorical_columns else ''} + {', ' + ', '.join(f'AVG({col}) as avg_{col}, MAX({col}) as max_{col}, MIN({col}) as min_{col}' + for col in metric_columns[:5]) if metric_columns else ''} + FROM {database}.{schema}.{table_name} + """ + + # Get sample records + sample_query = f""" + SELECT * + FROM {database}.{schema}.{table_name} + LIMIT 5 + """ + + try: + # Execute summary query + cursor.execute(summary_query) + summary_results = cursor.fetchone() + + # Execute sample query + cursor.execute(sample_query) + sample_results = cursor.fetchall() + sample_columns = [desc[0] for desc in cursor.description] + + # Format summary + table_summary = f"\n=== Summary for {table_name} ===\n" + table_summary += f"Total Records: {summary_results[0]}\n" + + # Add categorical summaries + col_index = 1 + for col in categorical_columns[:5]: + if col_index < len(summary_results): + table_summary += f"Unique {col}: {summary_results[col_index]}\n" + col_index += 1 + + # Add metric summaries + for col in metric_columns[:5]: + if col_index + 2 < len(summary_results): + avg_val = summary_results[col_index] + max_val = summary_results[col_index + 1] + min_val = summary_results[col_index + 2] + table_summary += f"{col} - Avg: {avg_val:.2f}, Max: {max_val}, Min: {min_val}\n" + col_index += 3 + + # Add sample records + table_summary += "\nSample Records:\n" + sample_df = pd.DataFrame(sample_results, columns=sample_columns) + table_summary += sample_df.to_string() + + summaries.append(table_summary) + + except Exception as e: + st.error(f"Error generating summary for {table_name}: {e}") + continue + + return "\n\n".join(summaries) + + except Exception as e: + st.error(f"Error generating data summary: {e}") + return f"Error generating summary: {str(e)}" + + +def handle_user_query(nl_query, db_name, schema_name, model_id=DEFAULT_MODEL_ID): + """ + Handle both SQL queries and summary requests. + + Args: + nl_query (str): Natural language query. + db_name (str): Database name. + schema_name (str): Schema name. + model_id (str, optional): The model ID to use. Defaults to DEFAULT_MODEL_ID. 
+ + Returns: + tuple: (sql_query, summary) + """ + if any(keyword in nl_query.lower() for keyword in ['summarize', 'summary', 'overview', 'analyze']): + # Extract the topic from the query + topic = nl_query.lower().replace('summarize', '').replace('summary', '').replace('overview', '').replace('analyze', '').strip() + summary = generate_data_summary(db_name, schema_name, topic) + return None, summary # Return (sql_query, summary) + else: + # Generate SQL query as before + sql_query = generate_sql_query(nl_query, db_name, schema_name, model_id) + return sql_query, None # Return (sql_query, summary) \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/utils/setup_utils.py b/Sales-Analyst-Bedrock-Snowflake/src/utils/setup_utils.py new file mode 100644 index 00000000..2182e422 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/utils/setup_utils.py @@ -0,0 +1,147 @@ +""" +Setup utilities for the GenAI Sales Analyst application. +""" +import boto3 +import logging +from botocore.exceptions import ClientError +import os +import yaml +from typing import Dict + + +class SetupManager: + def __init__(self, region_name: str = 'us-east-1'): + self.region = region_name + self.s3 = boto3.client('s3', region_name=region_name) + self.bedrock = boto3.client('bedrock-runtime', region_name=region_name) + logging.basicConfig(level=logging.INFO) + self.logger = logging.getLogger(__name__) + + def create_bucket(self, bucket_name: str) -> bool: + """Create an S3 bucket for metadata storage""" + try: + if self.region == 'us-east-1': + self.s3.create_bucket(Bucket=bucket_name) + else: + self.s3.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={ + 'LocationConstraint': self.region + } + ) + + # Enable versioning + self.s3.put_bucket_versioning( + Bucket=bucket_name, + VersioningConfiguration={'Status': 'Enabled'} + ) + + # Create folders + folders = ['metadata', 'vector_store', 'models'] + for folder in folders: + self.s3.put_object( + Bucket=bucket_name, + Key=f'{folder}/' + ) + + self.logger.info(f"Successfully created bucket: {bucket_name}") + return True + + except ClientError as e: + if e.response['Error']['Code'] == 'BucketAlreadyOwnedByYou': + self.logger.info(f"Bucket {bucket_name} already exists") + return True + elif e.response['Error']['Code'] == 'BucketAlreadyExists': + self.logger.error(f"Bucket {bucket_name} already exists but is owned by another account") + return False + else: + self.logger.error(f"Error creating bucket: {e}") + return False + + def verify_bedrock_access(self) -> bool: + """Verify Bedrock access and available models""" + try: + response = self.bedrock.list_foundation_models() + models = [model['modelId'] for model in response['modelSummaries']] + required_models = [ + "anthropic.claude-3-sonnet-20240229-v1:0", + "amazon.titan-embed-text-v1" + ] + + for model in required_models: + if model not in models: + self.logger.warning(f"Required model {model} not available") + return False + + self.logger.info("Bedrock access verified") + return True + + except Exception as e: + self.logger.error(f"Error verifying Bedrock access: {e}") + return False + + def create_streamlit_secrets(self, config: Dict) -> bool: + """Create Streamlit secrets file""" + try: + os.makedirs('.streamlit', exist_ok=True) + with open('.streamlit/secrets.toml', 'w') as f: + yaml.dump(config, f) + self.logger.info("Created Streamlit secrets file") + return True + except Exception as e: + self.logger.error(f"Error creating Streamlit secrets: {e}") + return False + + +def 
run_setup(): + """Run complete setup process""" + setup = SetupManager() + + print("Starting setup process...") + + # 1. Get configuration + config = { + 'aws': { + 'region': input("Enter AWS region (default: us-east-1): ") or 'us-east-1', + 's3_bucket': input("Enter unique S3 bucket name: "), + }, + 'langfuse': { + 'public_key': input("Enter LangFuse public key (optional): "), + 'secret_key': input("Enter LangFuse secret key (optional): "), + } + } + + # 2. Create S3 bucket + print("\nCreating S3 bucket...") + if setup.create_bucket(config['aws']['s3_bucket']): + print("βœ“ S3 bucket created successfully") + else: + print("βœ— Failed to create S3 bucket") + return + + # 3. Verify Bedrock access + print("\nVerifying Bedrock access...") + if setup.verify_bedrock_access(): + print("βœ“ Bedrock access verified") + else: + print("βœ— Bedrock access verification failed") + return + + # 4. Create Streamlit secrets + print("\nCreating Streamlit secrets...") + if setup.create_streamlit_secrets(config): + print("βœ“ Streamlit secrets created") + else: + print("βœ— Failed to create Streamlit secrets") + return + + print("\nSetup completed successfully!") + print(f"Your application is ready to use with bucket: {config['aws']['s3_bucket']}") + print("\nNext steps:") + print("1. Run: streamlit run app.py") + print("2. Upload your metadata through the web interface") + print("3. Start asking questions!") + + +if __name__ == "__main__": + run_setup() \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/utils/snowflake_connector.py b/Sales-Analyst-Bedrock-Snowflake/src/utils/snowflake_connector.py new file mode 100644 index 00000000..2fcbca17 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/utils/snowflake_connector.py @@ -0,0 +1,201 @@ +""" +Snowflake connector for the GenAI Sales Analyst application. +""" +import os +import snowflake.connector +from dotenv import load_dotenv + +# Load environment variables (override existing ones) +load_dotenv(override=True) + +def get_snowflake_connection(): + """ + Get a connection to Snowflake. + + Returns: + Snowflake connection object + """ + # Get credentials from environment variables or update these placeholders + user = os.getenv('SNOWFLAKE_USER', '') + password = os.getenv('SNOWFLAKE_PASSWORD', '') + account = os.getenv('SNOWFLAKE_ACCOUNT', '') + warehouse = os.getenv('SNOWFLAKE_WAREHOUSE', 'COMPUTE_WH') + role = os.getenv('SNOWFLAKE_ROLE', 'ACCOUNTADMIN') + + + + # Connect to Snowflake + conn = snowflake.connector.connect( + user=user, + password=password, + account=account, + warehouse=warehouse, + role=role + ) + + return conn + +def execute_query(query): + """ + Execute a SQL query on Snowflake. 
+ + Args: + query: SQL query to execute + + Returns: + List of dictionaries with query results + """ + conn = get_snowflake_connection() + try: + cursor = conn.cursor() + + # Split the query into multiple statements if needed + statements = query.split(';') + statements = [stmt.strip() for stmt in statements if stmt.strip()] + + # Execute all statements except the last one without fetching results + for stmt in statements[:-1]: + cursor.execute(stmt) + + # Execute the last statement and fetch results + if statements: + cursor.execute(statements[-1]) + results = cursor.fetchall() + + # Convert to list of dictionaries + columns = [desc[0] for desc in cursor.description] if cursor.description else [] + return [dict(zip(columns, row)) for row in results] + return [] + finally: + conn.close() + +def get_available_databases(): + """ + Get a list of available databases. + + Returns: + List of database names + """ + conn = get_snowflake_connection() + try: + cursor = conn.cursor() + cursor.execute("SHOW DATABASES") + results = cursor.fetchall() + return [row[1] for row in results] # Column 1 contains the database name + finally: + conn.close() + +def get_available_schemas(database): + """ + Get a list of available schemas in a database. + + Args: + database: Database name + + Returns: + List of schema names + """ + conn = get_snowflake_connection() + try: + cursor = conn.cursor() + cursor.execute(f"SHOW SCHEMAS IN DATABASE {database}") + results = cursor.fetchall() + return [row[1] for row in results] # Column 1 contains the schema name + finally: + conn.close() + +def get_available_tables(database, schema): + """ + Get a list of available tables in a schema. + + Args: + database: Database name + schema: Schema name + + Returns: + List of table names + """ + conn = get_snowflake_connection() + try: + cursor = conn.cursor() + cursor.execute(f"SHOW TABLES IN {database}.{schema}") + results = cursor.fetchall() + return [row[1] for row in results] # Column 1 contains the table name + finally: + conn.close() + +def get_table_columns(database, schema, table): + """ + Get a list of columns in a table. 
+ + Args: + database: Database name + schema: Schema name + table: Table name + + Returns: + DataFrame with column information + """ + import pandas as pd + + conn = get_snowflake_connection() + try: + cursor = conn.cursor() + + # Use the database and schema + if database: + cursor.execute(f"USE DATABASE {database}") + if schema: + cursor.execute(f"USE SCHEMA {schema}") + + # Get column information + cursor.execute(f"DESCRIBE TABLE {table}") + results = cursor.fetchall() + + # Create DataFrame - handle variable number of columns from DESCRIBE TABLE + if results: + # Get actual number of columns returned + num_cols = len(results[0]) if results else 0 + if num_cols >= 11: + columns = ['name', 'type', 'kind', 'null', 'default', 'primary_key', 'unique_key', 'check', 'expression', 'comment', 'policy_name'] + if num_cols > 11: + # Add extra columns if they exist + for i in range(11, num_cols): + columns.append(f'extra_col_{i}') + else: + # Fallback for fewer columns + columns = [f'col_{i}' for i in range(num_cols)] + columns[0] = 'name' + columns[1] = 'type' if num_cols > 1 else 'name' + + df = pd.DataFrame(results, columns=columns) + else: + df = pd.DataFrame() + + # Rename columns to match expected format + rename_dict = {} + if 'name' in df.columns: + rename_dict['name'] = 'column_name' + if 'type' in df.columns: + rename_dict['type'] = 'data_type' + if 'comment' in df.columns: + rename_dict['comment'] = 'description' + + if rename_dict: + df = df.rename(columns=rename_dict) + + # Ensure required columns exist + if 'column_name' not in df.columns and len(df.columns) > 0: + df['column_name'] = df.iloc[:, 0] + if 'data_type' not in df.columns and len(df.columns) > 1: + df['data_type'] = df.iloc[:, 1] + if 'description' not in df.columns: + df['description'] = '' + + # Add table name + if not df.empty: + df['table_name'] = table + + return df + finally: + conn.close() \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/vector_store/__init__.py b/Sales-Analyst-Bedrock-Snowflake/src/vector_store/__init__.py new file mode 100644 index 00000000..d9c930e1 --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/vector_store/__init__.py @@ -0,0 +1,3 @@ +""" +Vector store module for the GenAI Sales Analyst application. +""" \ No newline at end of file diff --git a/Sales-Analyst-Bedrock-Snowflake/src/vector_store/faiss_manager.py b/Sales-Analyst-Bedrock-Snowflake/src/vector_store/faiss_manager.py new file mode 100644 index 00000000..de1cc2cf --- /dev/null +++ b/Sales-Analyst-Bedrock-Snowflake/src/vector_store/faiss_manager.py @@ -0,0 +1,138 @@ +""" +FAISS vector store manager for the GenAI Sales Analyst application. +""" +import faiss +import numpy as np +import pickle +import boto3 +from datetime import datetime +from typing import List, Dict, Any, Optional + + +class FAISSManager: + """ + Manages FAISS vector store operations. + """ + + def __init__(self, bedrock_client, s3_bucket: Optional[str] = None, dimension: int = 1536): + """ + Initialize the FAISS manager. + + Args: + bedrock_client: Client for Amazon Bedrock API + s3_bucket: S3 bucket name for storing indices + dimension: Dimension of the embedding vectors + """ + self.bedrock_client = bedrock_client + self.s3_bucket = s3_bucket + self.index = faiss.IndexFlatL2(dimension) + self.texts = [] + self.metadata = [] + + def add_texts(self, texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None): + """ + Add texts and their embeddings to the vector store. 
+ + Args: + texts: List of text strings to add + metadatas: Optional list of metadata dictionaries + """ + if metadatas is None: + metadatas = [{} for _ in texts] + + embeddings = [] + for text in texts: + embedding = self.bedrock_client.get_embeddings(text) + embeddings.append(embedding) + + embeddings_array = np.array(embeddings).astype('float32') + self.index.add(embeddings_array) + self.texts.extend(texts) + self.metadata.extend(metadatas) + + def similarity_search(self, query: str, k: int = 4) -> List[Dict[str, Any]]: + """ + Search for similar texts based on the query. + + Args: + query: Query text + k: Number of results to return + + Returns: + List of dictionaries containing text, metadata, and distance + """ + # Handle empty index + if len(self.texts) == 0: + return [] + + try: + # Get query embedding + query_embedding = self.bedrock_client.get_embeddings(query) + query_array = np.array([query_embedding]).astype('float32') + + # Limit k to the number of items in the index + k = min(k, len(self.texts)) + if k == 0: + return [] + + # Search + distances, indices = self.index.search(query_array, k) + + results = [] + for i, idx in enumerate(indices[0]): + if idx < len(self.texts) and idx >= 0: + results.append({ + 'text': self.texts[idx], + 'metadata': self.metadata[idx], + 'distance': float(distances[0][i]) + }) + return results + except Exception as e: + print(f"Error in similarity search: {str(e)}") + # Return empty results on error + return [] + + def save_index(self) -> str: + """ + Save the index and data to S3. + + Returns: + Message indicating where the index was saved + """ + if not self.s3_bucket: + return "No S3 bucket specified, index not saved" + + s3 = boto3.client('s3') + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + + index_bytes = faiss.serialize_index(self.index) + index_key = f'vector_store/index_{timestamp}.faiss' + s3.put_object(Bucket=self.s3_bucket, Key=index_key, Body=index_bytes) + + data = {'texts': self.texts, 'metadata': self.metadata} + data_key = f'vector_store/data_{timestamp}.pkl' + s3.put_object(Bucket=self.s3_bucket, Key=data_key, Body=pickle.dumps(data)) + + return f"Index saved: {index_key}, Data saved: {data_key}" + + def load_index(self, index_key: str, data_key: str): + """ + Load the index and data from S3. + + Args: + index_key: S3 key for the index file + data_key: S3 key for the data file + """ + if not self.s3_bucket: + raise ValueError("No S3 bucket specified") + + s3 = boto3.client('s3') + + index_response = s3.get_object(Bucket=self.s3_bucket, Key=index_key) + index_bytes = index_response['Body'].read() + self.index = faiss.deserialize_index(index_bytes) + + data_response = s3.get_object(Bucket=self.s3_bucket, Key=data_key) + data = pickle.loads(data_response['Body'].read()) + self.texts = data['texts'] + self.metadata = data['metadata'] \ No newline at end of file
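Note that `FAISSManager.add_texts` and `similarity_search` call `self.bedrock_client.get_embeddings(...)`, a method the `bedrock_client.py` utilities in this change do not define. The snippet below is a minimal sketch, not part of this change, of a hypothetical adapter that satisfies that interface with the `amazon.titan-embed-text-v1` embeddings model that `setup_utils.py` verifies access to (its 1536-dimensional output matches the manager's default index dimension), followed by a small usage example. The `TitanEmbeddingsClient` name and the `src.vector_store.faiss_manager` import path are assumptions based on the layout shown in this diff.

```python
# Minimal sketch (not part of this change): a thin embeddings wrapper that
# provides the get_embeddings() method FAISSManager expects, backed by the
# Titan embeddings model that setup_utils.py checks for.
import json

import boto3

from src.vector_store.faiss_manager import FAISSManager  # path assumed from this diff


class TitanEmbeddingsClient:
    """Hypothetical adapter exposing get_embeddings() for FAISSManager."""

    def __init__(self, region_name: str = "us-east-1",
                 model_id: str = "amazon.titan-embed-text-v1"):
        self.runtime = boto3.client("bedrock-runtime", region_name=region_name)
        self.model_id = model_id

    def get_embeddings(self, text: str) -> list[float]:
        # Titan text embeddings take {"inputText": ...} and return {"embedding": [...]}
        response = self.runtime.invoke_model(
            modelId=self.model_id,
            body=json.dumps({"inputText": text}),
            contentType="application/json",
            accept="application/json",
        )
        payload = json.loads(response["body"].read())
        return payload["embedding"]


if __name__ == "__main__":
    # Index a couple of schema descriptions and run a similarity search.
    client = TitanEmbeddingsClient()
    store = FAISSManager(bedrock_client=client, s3_bucket=None)  # no S3 persistence here
    store.add_texts(
        ["ORDERS.ORDER_DATE: date the order was placed",
         "CUSTOMERS.COMPANY_NAME: customer company name"],
        [{"table": "ORDERS"}, {"table": "CUSTOMERS"}],
    )
    for hit in store.similarity_search("When were orders placed?", k=1):
        print(hit["text"], hit["distance"])
```

With an adapter like this, `process_uploaded_data` in `helpers.py` can be handed a `FAISSManager` instance directly; without S3 configured, `save_index()` simply reports that the index was not persisted.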