-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tests: add tests for config, smart dataframe and smart datalake
- Loading branch information
Showing
5 changed files
with
341 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
122 changes: 122 additions & 0 deletions
122
tests/unit_tests/smart_dataframe/test_smart_dataframe.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import warnings | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from pandasai.config import Config | ||
from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes | ||
|
||
|
||
def test_smart_dataframe_init_basic(): | ||
# Create a sample dataframe | ||
df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) | ||
|
||
# Test initialization with minimal parameters | ||
with pytest.warns(DeprecationWarning): | ||
smart_df = SmartDataframe(df) | ||
|
||
assert smart_df._original_import is df | ||
assert isinstance(smart_df.dataframe, pd.DataFrame) | ||
assert smart_df._table_name is None | ||
assert smart_df._table_description is None | ||
assert smart_df._custom_head is None | ||
|
||
|
||
def test_smart_dataframe_init_with_all_params(): | ||
# Create sample dataframes | ||
df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) | ||
custom_head = pd.DataFrame({"A": [1], "B": ["x"]}) | ||
config = Config() | ||
|
||
# Test initialization with all parameters | ||
with pytest.warns(DeprecationWarning): | ||
smart_df = SmartDataframe( | ||
df, | ||
name="test_df", | ||
description="Test dataframe", | ||
custom_head=custom_head, | ||
config=config, | ||
) | ||
|
||
assert smart_df._original_import is df | ||
assert isinstance(smart_df.dataframe, pd.DataFrame) | ||
assert smart_df._table_name == "test_df" | ||
assert smart_df._table_description == "Test dataframe" | ||
assert smart_df._custom_head == custom_head.to_csv(index=False) | ||
assert smart_df._agent._state._config == config | ||
|
||
|
||
def test_smart_dataframe_deprecation_warning(): | ||
df = pd.DataFrame({"A": [1, 2, 3]}) | ||
|
||
with warnings.catch_warnings(record=True) as warning_info: | ||
warnings.simplefilter("always") | ||
SmartDataframe(df) | ||
|
||
deprecation_warnings = [ | ||
w for w in warning_info if issubclass(w.category, DeprecationWarning) | ||
] | ||
assert len(deprecation_warnings) >= 1 | ||
assert "SmartDataframe will soon be deprecated" in str( | ||
deprecation_warnings[0].message | ||
) | ||
|
||
|
||
def test_load_df_success(): | ||
# Create sample dataframes | ||
original_df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) | ||
with pytest.warns(DeprecationWarning): | ||
smart_df = SmartDataframe(original_df) | ||
|
||
# Test loading a new dataframe | ||
new_df = pd.DataFrame({"C": [4, 5, 6], "D": ["a", "b", "c"]}) | ||
loaded_df = smart_df.load_df( | ||
new_df, | ||
name="new_df", | ||
description="New test dataframe", | ||
custom_head=pd.DataFrame({"C": [4], "D": ["a"]}), | ||
) | ||
|
||
assert isinstance(loaded_df, pd.DataFrame) | ||
assert loaded_df.equals(new_df) | ||
|
||
|
||
def test_load_df_invalid_input(): | ||
# Create a sample dataframe | ||
original_df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) | ||
with pytest.warns(DeprecationWarning): | ||
smart_df = SmartDataframe(original_df) | ||
|
||
# Test loading invalid data | ||
with pytest.raises( | ||
ValueError, match="Invalid input data. We cannot convert it to a dataframe." | ||
): | ||
smart_df.load_df( | ||
"not a dataframe", | ||
name="invalid_df", | ||
description="Invalid test data", | ||
custom_head=None, | ||
) | ||
|
||
|
||
def test_load_smartdataframes(): | ||
# Create sample dataframes | ||
df1 = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) | ||
df2 = pd.DataFrame({"C": [4, 5, 6], "D": ["a", "b", "c"]}) | ||
|
||
# Create a config | ||
config = Config() | ||
|
||
# Test loading regular pandas DataFrames | ||
smart_dfs = load_smartdataframes([df1, df2], config) | ||
assert len(smart_dfs) == 2 | ||
assert all(isinstance(df, SmartDataframe) for df in smart_dfs) | ||
assert all(hasattr(df, "config") for df in smart_dfs) | ||
|
||
# Test loading mixed pandas DataFrames and SmartDataframes | ||
existing_smart_df = SmartDataframe(df1, config=config) | ||
mixed_dfs = load_smartdataframes([existing_smart_df, df2], config) | ||
assert len(mixed_dfs) == 2 | ||
assert mixed_dfs[0] is existing_smart_df # Should return the same instance | ||
assert isinstance(mixed_dfs[1], SmartDataframe) | ||
assert hasattr(mixed_dfs[1], "config") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from unittest.mock import Mock | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from pandasai.config import Config | ||
from pandasai.smart_datalake import SmartDatalake | ||
|
||
|
||
@pytest.fixture | ||
def sample_dataframes(): | ||
df1 = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) | ||
df2 = pd.DataFrame({"C": [7, 8, 9], "D": [10, 11, 12]}) | ||
return [df1, df2] | ||
|
||
|
||
def test_dfs_property(sample_dataframes): | ||
# Create a mock agent with context | ||
mock_agent = Mock() | ||
mock_agent.context.dfs = sample_dataframes | ||
|
||
# Create SmartDatalake instance | ||
smart_datalake = SmartDatalake(sample_dataframes) | ||
smart_datalake._agent = mock_agent # Inject mock agent | ||
|
||
# Test that dfs property returns the correct dataframes | ||
assert smart_datalake.dfs == sample_dataframes | ||
|
||
|
||
def test_enable_cache(sample_dataframes): | ||
# Create a mock agent with context and config | ||
mock_config = Config(enable_cache=True) | ||
mock_agent = Mock() | ||
mock_agent.context.config = mock_config | ||
|
||
# Create SmartDatalake instance | ||
smart_datalake = SmartDatalake(sample_dataframes) | ||
smart_datalake._agent = mock_agent # Inject mock agent | ||
|
||
# Test that enable_cache property returns the correct value | ||
assert smart_datalake.enable_cache is True | ||
|
||
# Test with cache disabled | ||
mock_config.enable_cache = False | ||
assert smart_datalake.enable_cache is False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
from pandasai.config import APIKeyManager | ||
|
||
|
||
def test_set_api_key(): | ||
# Setup | ||
test_api_key = "test-api-key-123" | ||
|
||
# Execute | ||
with patch.dict(os.environ, {}, clear=True): | ||
APIKeyManager.set(test_api_key) | ||
|
||
# Assert | ||
assert os.environ.get("PANDABI_API_KEY") == test_api_key | ||
assert APIKeyManager._api_key == test_api_key | ||
|
||
|
||
def test_get_api_key(): | ||
# Setup | ||
test_api_key = "test-api-key-123" | ||
APIKeyManager._api_key = test_api_key | ||
|
||
# Execute | ||
result = APIKeyManager.get() | ||
|
||
# Assert | ||
assert result == test_api_key | ||
|
||
|
||
def test_get_api_key_when_none(): | ||
# Setup | ||
APIKeyManager._api_key = None | ||
|
||
# Execute | ||
result = APIKeyManager.get() | ||
|
||
# Assert | ||
assert result is None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import os | ||
from unittest.mock import MagicMock, patch | ||
|
||
from pandasai.config import Config, ConfigManager | ||
from pandasai.helpers.filemanager import DefaultFileManager | ||
from pandasai.llm.bamboo_llm import BambooLLM | ||
|
||
|
||
def test_validate_llm_with_pandabi_api_key(): | ||
# Setup | ||
ConfigManager._config = MagicMock() | ||
ConfigManager._config.llm = None | ||
|
||
with patch.dict(os.environ, {"PANDABI_API_KEY": "test-key"}): | ||
# Execute | ||
ConfigManager.validate_llm() | ||
|
||
# Assert | ||
assert isinstance(ConfigManager._config.llm, BambooLLM) | ||
|
||
|
||
def test_validate_llm_with_langchain(): | ||
# Setup | ||
ConfigManager._config = MagicMock() | ||
mock_llm = MagicMock() | ||
ConfigManager._config.llm = mock_llm | ||
mock_langchain_llm = MagicMock() | ||
|
||
# Create mock module | ||
mock_langchain_module = MagicMock() | ||
mock_langchain_module.__spec__ = MagicMock() | ||
mock_langchain_module.langchain = MagicMock( | ||
LangchainLLM=mock_langchain_llm, is_langchain_llm=lambda x: True | ||
) | ||
|
||
with patch.dict( | ||
"sys.modules", | ||
{ | ||
"pandasai_langchain": mock_langchain_module, | ||
"pandasai_langchain.langchain": mock_langchain_module.langchain, | ||
}, | ||
): | ||
# Execute | ||
ConfigManager.validate_llm() | ||
|
||
# Assert | ||
assert mock_langchain_llm.call_count == 1 | ||
|
||
|
||
def test_validate_llm_no_action_needed(): | ||
# Setup | ||
ConfigManager._config = MagicMock() | ||
mock_llm = MagicMock() | ||
ConfigManager._config.llm = mock_llm | ||
|
||
# Case where no PANDABI_API_KEY and not a langchain LLM | ||
with patch.dict(os.environ, {}, clear=True): | ||
with patch("importlib.util.find_spec") as mock_find_spec: | ||
mock_find_spec.return_value = None | ||
|
||
# Execute | ||
ConfigManager.validate_llm() | ||
|
||
# Assert - llm should remain unchanged | ||
assert ConfigManager._config.llm == mock_llm | ||
|
||
|
||
def test_config_update(): | ||
# Setup | ||
mock_config = MagicMock() | ||
initial_config = {"key1": "value1", "key2": "value2"} | ||
mock_config.model_dump = MagicMock(return_value=initial_config.copy()) | ||
ConfigManager._config = mock_config | ||
|
||
# Create a mock for Config.from_dict | ||
original_from_dict = Config.from_dict | ||
Config.from_dict = MagicMock() | ||
|
||
try: | ||
# Execute | ||
new_config = {"key2": "new_value2", "key3": "value3"} | ||
ConfigManager.update(new_config) | ||
|
||
# Assert | ||
expected_config = {"key1": "value1", "key2": "new_value2", "key3": "value3"} | ||
assert mock_config.model_dump.call_count == 1 | ||
Config.from_dict.assert_called_once_with(expected_config) | ||
finally: | ||
# Restore original from_dict method | ||
Config.from_dict = original_from_dict | ||
|
||
|
||
def test_config_set(): | ||
# Setup | ||
test_config = {"key": "value"} | ||
with patch.object(Config, "from_dict") as mock_from_dict, patch.object( | ||
ConfigManager, "validate_llm" | ||
) as mock_validate_llm: | ||
# Execute | ||
ConfigManager.set(test_config) | ||
|
||
# Assert | ||
mock_from_dict.assert_called_once_with(test_config) | ||
mock_validate_llm.assert_called_once() | ||
|
||
|
||
def test_config_from_dict(): | ||
# Test with default overrides | ||
config_dict = { | ||
"save_logs": False, | ||
"verbose": True, | ||
"enable_cache": False, | ||
"max_retries": 5, | ||
} | ||
|
||
config = Config.from_dict(config_dict) | ||
|
||
assert isinstance(config, Config) | ||
assert config.save_logs == False | ||
assert config.verbose == True | ||
assert config.enable_cache == False | ||
assert config.max_retries == 5 | ||
assert config.llm is None | ||
assert isinstance(config.file_manager, DefaultFileManager) | ||
|
||
# Test with minimal dict | ||
minimal_config = Config.from_dict({}) | ||
assert isinstance(minimal_config, Config) | ||
assert minimal_config.save_logs == True # default value | ||
assert minimal_config.verbose == False # default value | ||
assert minimal_config.enable_cache == True # default value | ||
assert minimal_config.max_retries == 3 # default value |