-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tests: add tests for config, smart dataframe and smart datalake
- Loading branch information
Showing
5 changed files
with
299 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
123 changes: 123 additions & 0 deletions
123
tests/unit_tests/smart_dataframe/test_smart_dataframe.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import warnings | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from pandasai.config import Config | ||
from pandasai.llm.fake import FakeLLM | ||
from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes | ||
|
||
|
||
def test_smart_dataframe_init_basic(): | ||
# Create a sample dataframe | ||
df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) | ||
|
||
# Test initialization with minimal parameters | ||
with pytest.warns(DeprecationWarning): | ||
smart_df = SmartDataframe(df) | ||
|
||
assert smart_df._original_import is df | ||
assert isinstance(smart_df.dataframe, pd.DataFrame) | ||
assert smart_df._table_name is None | ||
assert smart_df._table_description is None | ||
assert smart_df._custom_head is None | ||
|
||
|
||
def test_smart_dataframe_init_with_all_params(): | ||
# Create sample dataframes | ||
df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) | ||
custom_head = pd.DataFrame({"A": [1], "B": ["x"]}) | ||
config = Config(llm=FakeLLM()) | ||
|
||
# Test initialization with all parameters | ||
with pytest.warns(DeprecationWarning): | ||
smart_df = SmartDataframe( | ||
df, | ||
name="test_df", | ||
description="Test dataframe", | ||
custom_head=custom_head, | ||
config=config, | ||
) | ||
|
||
assert smart_df._original_import is df | ||
assert isinstance(smart_df.dataframe, pd.DataFrame) | ||
assert smart_df._table_name == "test_df" | ||
assert smart_df._table_description == "Test dataframe" | ||
assert smart_df._custom_head == custom_head.to_csv(index=False) | ||
assert smart_df._agent._state._config == config | ||
|
||
|
||
def test_smart_dataframe_deprecation_warning(): | ||
df = pd.DataFrame({"A": [1, 2, 3]}) | ||
|
||
with warnings.catch_warnings(record=True) as warning_info: | ||
warnings.simplefilter("always") | ||
SmartDataframe(df) | ||
|
||
deprecation_warnings = [ | ||
w for w in warning_info if issubclass(w.category, DeprecationWarning) | ||
] | ||
assert len(deprecation_warnings) >= 1 | ||
assert "SmartDataframe will soon be deprecated" in str( | ||
deprecation_warnings[0].message | ||
) | ||
|
||
|
||
def test_load_df_success(): | ||
# Create sample dataframes | ||
original_df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) | ||
with pytest.warns(DeprecationWarning): | ||
smart_df = SmartDataframe(original_df) | ||
|
||
# Test loading a new dataframe | ||
new_df = pd.DataFrame({"C": [4, 5, 6], "D": ["a", "b", "c"]}) | ||
loaded_df = smart_df.load_df( | ||
new_df, | ||
name="new_df", | ||
description="New test dataframe", | ||
custom_head=pd.DataFrame({"C": [4], "D": ["a"]}), | ||
) | ||
|
||
assert isinstance(loaded_df, pd.DataFrame) | ||
assert loaded_df.equals(new_df) | ||
|
||
|
||
def test_load_df_invalid_input(): | ||
# Create a sample dataframe | ||
original_df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) | ||
with pytest.warns(DeprecationWarning): | ||
smart_df = SmartDataframe(original_df) | ||
|
||
# Test loading invalid data | ||
with pytest.raises( | ||
ValueError, match="Invalid input data. We cannot convert it to a dataframe." | ||
): | ||
smart_df.load_df( | ||
"not a dataframe", | ||
name="invalid_df", | ||
description="Invalid test data", | ||
custom_head=None, | ||
) | ||
|
||
|
||
def test_load_smartdataframes(): | ||
# Create sample dataframes | ||
df1 = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) | ||
df2 = pd.DataFrame({"C": [4, 5, 6], "D": ["a", "b", "c"]}) | ||
|
||
# Create a config with FakeLLM | ||
config = Config(llm=FakeLLM()) | ||
|
||
# Test loading regular pandas DataFrames | ||
smart_dfs = load_smartdataframes([df1, df2], config) | ||
assert len(smart_dfs) == 2 | ||
assert all(isinstance(df, SmartDataframe) for df in smart_dfs) | ||
assert all(hasattr(df, "config") for df in smart_dfs) | ||
|
||
# Test loading mixed pandas DataFrames and SmartDataframes | ||
existing_smart_df = SmartDataframe(df1, config=config) | ||
mixed_dfs = load_smartdataframes([existing_smart_df, df2], config) | ||
assert len(mixed_dfs) == 2 | ||
assert mixed_dfs[0] is existing_smart_df # Should return the same instance | ||
assert isinstance(mixed_dfs[1], SmartDataframe) | ||
assert hasattr(mixed_dfs[1], "config") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from unittest.mock import Mock | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from pandasai.config import Config | ||
from pandasai.smart_datalake import SmartDatalake | ||
|
||
|
||
@pytest.fixture | ||
def sample_dataframes(): | ||
df1 = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) | ||
df2 = pd.DataFrame({"C": [7, 8, 9], "D": [10, 11, 12]}) | ||
return [df1, df2] | ||
|
||
|
||
def test_dfs_property(sample_dataframes): | ||
# Create a mock agent with context | ||
mock_agent = Mock() | ||
mock_agent.context.dfs = sample_dataframes | ||
|
||
# Create SmartDatalake instance | ||
smart_datalake = SmartDatalake(sample_dataframes) | ||
smart_datalake._agent = mock_agent # Inject mock agent | ||
|
||
# Test that dfs property returns the correct dataframes | ||
assert smart_datalake.dfs == sample_dataframes | ||
|
||
|
||
def test_enable_cache(sample_dataframes): | ||
# Create a mock agent with context and config | ||
mock_config = Config(enable_cache=True) | ||
mock_agent = Mock() | ||
mock_agent.context.config = mock_config | ||
|
||
# Create SmartDatalake instance | ||
smart_datalake = SmartDatalake(sample_dataframes) | ||
smart_datalake._agent = mock_agent # Inject mock agent | ||
|
||
# Test that enable_cache property returns the correct value | ||
assert smart_datalake.enable_cache is True | ||
|
||
# Test with cache disabled | ||
mock_config.enable_cache = False | ||
assert smart_datalake.enable_cache is False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
from pandasai.config import APIKeyManager | ||
|
||
|
||
def test_set_api_key(): | ||
# Setup | ||
test_api_key = "test-api-key-123" | ||
|
||
# Execute | ||
with patch.dict(os.environ, {}, clear=True): | ||
APIKeyManager.set(test_api_key) | ||
|
||
# Assert | ||
assert os.environ.get("PANDABI_API_KEY") == test_api_key | ||
assert APIKeyManager._api_key == test_api_key | ||
|
||
|
||
def test_get_api_key(): | ||
# Setup | ||
test_api_key = "test-api-key-123" | ||
APIKeyManager._api_key = test_api_key | ||
|
||
# Execute | ||
result = APIKeyManager.get() | ||
|
||
# Assert | ||
assert result == test_api_key | ||
|
||
|
||
def test_get_api_key_when_none(): | ||
# Setup | ||
APIKeyManager._api_key = None | ||
|
||
# Execute | ||
result = APIKeyManager.get() | ||
|
||
# Assert | ||
assert result is None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import os | ||
from importlib.util import find_spec | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
|
||
from pandasai.config import APIKeyManager, Config, ConfigManager | ||
from pandasai.llm.bamboo_llm import BambooLLM | ||
|
||
|
||
class TestConfigManager: | ||
def setup_method(self): | ||
# Reset the ConfigManager state before each test | ||
ConfigManager._config = None | ||
ConfigManager._initialized = False | ||
|
||
def test_validate_llm_with_pandabi_api_key(self): | ||
"""Test validate_llm when PANDABI_API_KEY is set""" | ||
with patch.dict(os.environ, {"PANDABI_API_KEY": "test-key"}): | ||
ConfigManager._config = MagicMock() | ||
ConfigManager._config.llm = None | ||
|
||
ConfigManager.validate_llm() | ||
|
||
assert isinstance(ConfigManager._config.llm, BambooLLM) | ||
|
||
def test_validate_llm_without_pandabi_api_key(self): | ||
"""Test validate_llm when PANDABI_API_KEY is not set""" | ||
with patch.dict(os.environ, {}, clear=True): | ||
ConfigManager._config = MagicMock() | ||
ConfigManager._config.llm = None | ||
|
||
ConfigManager.validate_llm() | ||
|
||
assert ConfigManager._config.llm is None | ||
|
||
@pytest.mark.skipif( | ||
find_spec("pandasai_langchain") is None, | ||
reason="pandasai_langchain not installed", | ||
) | ||
def test_validate_llm_with_langchain(self): | ||
"""Test validate_llm with langchain integration""" | ||
from pandasai_langchain.langchain import LangchainLLM | ||
|
||
mock_langchain_llm = MagicMock() | ||
ConfigManager._config = MagicMock() | ||
ConfigManager._config.llm = mock_langchain_llm | ||
|
||
with patch("pandasai_langchain.langchain.is_langchain_llm", return_value=True): | ||
ConfigManager.validate_llm() | ||
|
||
assert isinstance(ConfigManager._config.llm, LangchainLLM) | ||
assert ConfigManager._config.llm._llm == mock_langchain_llm | ||
|
||
def test_update_config(self): | ||
"""Test updating configuration with new values""" | ||
# Initialize config with some initial values | ||
initial_config = {"save_logs": True, "verbose": False} | ||
ConfigManager._config = Config.from_dict(initial_config) | ||
|
||
# Update with new values | ||
update_dict = {"verbose": True, "enable_cache": False} | ||
ConfigManager.update(update_dict) | ||
|
||
# Verify the configuration was updated correctly | ||
updated_config = ConfigManager._config.model_dump() | ||
assert updated_config["save_logs"] is True # Original value preserved | ||
assert updated_config["verbose"] is True # Value updated | ||
assert updated_config["enable_cache"] is False # New value added | ||
|
||
def test_set_api_key(self): | ||
"""Test setting the API key""" | ||
test_api_key = "test-api-key-123" | ||
|
||
# Clear any existing API key | ||
if "PANDABI_API_KEY" in os.environ: | ||
del os.environ["PANDABI_API_KEY"] | ||
APIKeyManager._api_key = None | ||
|
||
# Set the API key | ||
APIKeyManager.set(test_api_key) | ||
|
||
# Verify the API key is set in both places | ||
assert os.environ["PANDABI_API_KEY"] == test_api_key | ||
assert APIKeyManager._api_key == test_api_key | ||
assert APIKeyManager.get() == test_api_key # Also test the get method |