From f6673670ddde7c0e2531f0a103e5dd7d73330b7d Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Fri, 31 Jan 2025 20:08:45 +0100 Subject: [PATCH] fix(sql): allow parameterized query through sql sanitization (#1576) --- .../connectors/sql/pandasai_sql/__init__.py | 19 +++++++++++-- pandasai/data_loader/sql_loader.py | 6 ++++ pandasai/helpers/sql_sanitizer.py | 8 +++++- .../unit_tests/data_loader/test_sql_loader.py | 28 +++++++++++++++++++ .../unit_tests/helpers/test_sql_sanitizer.py | 4 +++ 5 files changed, 61 insertions(+), 4 deletions(-) diff --git a/extensions/connectors/sql/pandasai_sql/__init__.py b/extensions/connectors/sql/pandasai_sql/__init__.py index 3cd517985..d01393dd5 100644 --- a/extensions/connectors/sql/pandasai_sql/__init__.py +++ b/extensions/connectors/sql/pandasai_sql/__init__.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional import pandas as pd @@ -17,7 +18,11 @@ def load_from_mysql( database=connection_info.database, port=connection_info.port, ) - return pd.read_sql(query, conn, params=params) + # Suppress warnings of SqlAlchemy + # TODO - Later can be removed when SqlAlchemy is used + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + return pd.read_sql(query, conn, params=params) def load_from_postgres( @@ -32,7 +37,11 @@ def load_from_postgres( dbname=connection_info.database, port=connection_info.port, ) - return pd.read_sql(query, conn, params=params) + # Suppress warnings of SqlAlchemy + # TODO - Later can be removed when SqlAlchemy is used + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + return pd.read_sql(query, conn, params=params) def load_from_cockroachdb( @@ -47,7 +56,11 @@ def load_from_cockroachdb( dbname=connection_info.database, port=connection_info.port, ) - return pd.read_sql(query, conn, params=params) + # Suppress warnings of SqlAlchemy + # TODO - Later can be removed when SqlAlchemy is used + with 
warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + return pd.read_sql(query, conn, params=params) __all__ = [ diff --git a/pandasai/data_loader/sql_loader.py b/pandasai/data_loader/sql_loader.py index a116f36b8..920c43fd3 100644 --- a/pandasai/data_loader/sql_loader.py +++ b/pandasai/data_loader/sql_loader.py @@ -47,6 +47,12 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra connection_info, formatted_query, params ) return self._apply_transformations(dataframe) + + except ModuleNotFoundError as e: + raise ImportError( + f"{source_type.capitalize()} connector not found. Please install the pandasai_sql[{source_type}] library, e.g. `pip install pandasai_sql[{source_type}]`." + ) from e + except Exception as e: raise RuntimeError( f"Failed to execute query for '{source_type}' with: {formatted_query}" diff --git a/pandasai/helpers/sql_sanitizer.py b/pandasai/helpers/sql_sanitizer.py index fb908d063..52c59346b 100644 --- a/pandasai/helpers/sql_sanitizer.py +++ b/pandasai/helpers/sql_sanitizer.py @@ -58,8 +58,14 @@ def is_sql_query_safe(query: str) -> bool: r"--", r"/\*.*\*/", # Block comments and inline comments ] + + placeholder = "___PLACEHOLDER___" # Temporary placeholder for params + + # Replace '%s' (MySQL, Psycopg2) with a unique placeholder + temp_query = query.replace("%s", placeholder) + # Parse the query to extract its structure - parsed = sqlglot.parse_one(query) + parsed = sqlglot.parse_one(temp_query) # Ensure the main query is SELECT if parsed.key.upper() != "SELECT": diff --git a/tests/unit_tests/data_loader/test_sql_loader.py b/tests/unit_tests/data_loader/test_sql_loader.py index b46354eae..2a9bbc277 100644 --- a/tests/unit_tests/data_loader/test_sql_loader.py +++ b/tests/unit_tests/data_loader/test_sql_loader.py @@ -197,3 +197,31 @@ def test_mysql_safe_query(self, mysql_schema): assert isinstance(result, DataFrame) mock_sql_query.assert_called_once_with("select * from users") + + def 
test_mysql_malicious_with_no_import(self, mysql_schema): + """Test that a ModuleNotFoundError raised by the loader function is re-raised as ImportError by execute_query.""" + with patch( + "pandasai.data_loader.sql_loader.is_sql_query_safe" + ) as mock_sql_query, patch( + "pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function" + ) as mock_loader_function: + mocked_exec_function = MagicMock() + mock_df = DataFrame( + pd.DataFrame( + { + "email": ["test@example.com"], + "first_name": ["John"], + "timestamp": [pd.Timestamp.now()], + } + ) + ) + mocked_exec_function.return_value = mock_df + + mock_exec_function = MagicMock() + mock_loader_function.return_value = mock_exec_function + mock_exec_function.side_effect = ModuleNotFoundError("Error") + loader = SQLDatasetLoader(mysql_schema, "test/users") + mock_sql_query.return_value = True + logging.debug("Loading schema from dataset path: %s", loader) + with pytest.raises(ImportError): + loader.execute_query("select * from users") diff --git a/tests/unit_tests/helpers/test_sql_sanitizer.py b/tests/unit_tests/helpers/test_sql_sanitizer.py index a572cc5f4..fbeab50de 100644 --- a/tests/unit_tests/helpers/test_sql_sanitizer.py +++ b/tests/unit_tests/helpers/test_sql_sanitizer.py @@ -82,6 +82,10 @@ def test_safe_query_with_subquery(self): query ) # Safe query with subquery, no dangerous keyword + def test_safe_query_with_query_params(self): + query = "SELECT * FROM (SELECT * FROM heart_data) AS filtered_data LIMIT %s OFFSET %s" + assert is_sql_query_safe(query) + if __name__ == "__main__": unittest.main()