Skip to content

Commit

Permalink
tests: add tests for config, smart dataframe and smart datalake
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Feb 3, 2025
1 parent 936c6f0 commit b9782e9
Show file tree
Hide file tree
Showing 5 changed files with 341 additions and 1 deletion.
1 change: 0 additions & 1 deletion pandasai/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from abc import ABC, abstractmethod
from importlib.util import find_spec
from typing import Any, Dict, Optional

Expand Down
122 changes: 122 additions & 0 deletions tests/unit_tests/smart_dataframe/test_smart_dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import warnings

import pandas as pd
import pytest

from pandasai.config import Config
from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes


def test_smart_dataframe_init_basic():
# Create a sample dataframe
df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})

# Test initialization with minimal parameters
with pytest.warns(DeprecationWarning):
smart_df = SmartDataframe(df)

assert smart_df._original_import is df
assert isinstance(smart_df.dataframe, pd.DataFrame)
assert smart_df._table_name is None
assert smart_df._table_description is None
assert smart_df._custom_head is None


def test_smart_dataframe_init_with_all_params():
# Create sample dataframes
df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})
custom_head = pd.DataFrame({"A": [1], "B": ["x"]})
config = Config()

# Test initialization with all parameters
with pytest.warns(DeprecationWarning):
smart_df = SmartDataframe(
df,
name="test_df",
description="Test dataframe",
custom_head=custom_head,
config=config,
)

assert smart_df._original_import is df
assert isinstance(smart_df.dataframe, pd.DataFrame)
assert smart_df._table_name == "test_df"
assert smart_df._table_description == "Test dataframe"
assert smart_df._custom_head == custom_head.to_csv(index=False)
assert smart_df._agent._state._config == config


def test_smart_dataframe_deprecation_warning():
df = pd.DataFrame({"A": [1, 2, 3]})

with warnings.catch_warnings(record=True) as warning_info:
warnings.simplefilter("always")
SmartDataframe(df)

deprecation_warnings = [
w for w in warning_info if issubclass(w.category, DeprecationWarning)
]
assert len(deprecation_warnings) >= 1
assert "SmartDataframe will soon be deprecated" in str(
deprecation_warnings[0].message
)


def test_load_df_success():
# Create sample dataframes
original_df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})
with pytest.warns(DeprecationWarning):
smart_df = SmartDataframe(original_df)

# Test loading a new dataframe
new_df = pd.DataFrame({"C": [4, 5, 6], "D": ["a", "b", "c"]})
loaded_df = smart_df.load_df(
new_df,
name="new_df",
description="New test dataframe",
custom_head=pd.DataFrame({"C": [4], "D": ["a"]}),
)

assert isinstance(loaded_df, pd.DataFrame)
assert loaded_df.equals(new_df)


def test_load_df_invalid_input():
# Create a sample dataframe
original_df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})
with pytest.warns(DeprecationWarning):
smart_df = SmartDataframe(original_df)

# Test loading invalid data
with pytest.raises(
ValueError, match="Invalid input data. We cannot convert it to a dataframe."
):
smart_df.load_df(
"not a dataframe",
name="invalid_df",
description="Invalid test data",
custom_head=None,
)


def test_load_smartdataframes():
# Create sample dataframes
df1 = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})
df2 = pd.DataFrame({"C": [4, 5, 6], "D": ["a", "b", "c"]})

# Create a config
config = Config()

# Test loading regular pandas DataFrames
smart_dfs = load_smartdataframes([df1, df2], config)
assert len(smart_dfs) == 2
assert all(isinstance(df, SmartDataframe) for df in smart_dfs)
assert all(hasattr(df, "config") for df in smart_dfs)

# Test loading mixed pandas DataFrames and SmartDataframes
existing_smart_df = SmartDataframe(df1, config=config)
mixed_dfs = load_smartdataframes([existing_smart_df, df2], config)
assert len(mixed_dfs) == 2
assert mixed_dfs[0] is existing_smart_df # Should return the same instance
assert isinstance(mixed_dfs[1], SmartDataframe)
assert hasattr(mixed_dfs[1], "config")
45 changes: 45 additions & 0 deletions tests/unit_tests/smart_datalake/test_smart_datalake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from unittest.mock import Mock

import pandas as pd
import pytest

from pandasai.config import Config
from pandasai.smart_datalake import SmartDatalake


@pytest.fixture
def sample_dataframes():
df1 = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
df2 = pd.DataFrame({"C": [7, 8, 9], "D": [10, 11, 12]})
return [df1, df2]


def test_dfs_property(sample_dataframes):
# Create a mock agent with context
mock_agent = Mock()
mock_agent.context.dfs = sample_dataframes

# Create SmartDatalake instance
smart_datalake = SmartDatalake(sample_dataframes)
smart_datalake._agent = mock_agent # Inject mock agent

# Test that dfs property returns the correct dataframes
assert smart_datalake.dfs == sample_dataframes


def test_enable_cache(sample_dataframes):
# Create a mock agent with context and config
mock_config = Config(enable_cache=True)
mock_agent = Mock()
mock_agent.context.config = mock_config

# Create SmartDatalake instance
smart_datalake = SmartDatalake(sample_dataframes)
smart_datalake._agent = mock_agent # Inject mock agent

# Test that enable_cache property returns the correct value
assert smart_datalake.enable_cache is True

# Test with cache disabled
mock_config.enable_cache = False
assert smart_datalake.enable_cache is False
42 changes: 42 additions & 0 deletions tests/unit_tests/test_api_key_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
from unittest.mock import patch

import pytest

from pandasai.config import APIKeyManager


def test_set_api_key():
# Setup
test_api_key = "test-api-key-123"

# Execute
with patch.dict(os.environ, {}, clear=True):
APIKeyManager.set(test_api_key)

# Assert
assert os.environ.get("PANDABI_API_KEY") == test_api_key
assert APIKeyManager._api_key == test_api_key


def test_get_api_key():
# Setup
test_api_key = "test-api-key-123"
APIKeyManager._api_key = test_api_key

# Execute
result = APIKeyManager.get()

# Assert
assert result == test_api_key


def test_get_api_key_when_none():
# Setup
APIKeyManager._api_key = None

# Execute
result = APIKeyManager.get()

# Assert
assert result is None
132 changes: 132 additions & 0 deletions tests/unit_tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import os
from unittest.mock import MagicMock, patch

from pandasai.config import Config, ConfigManager
from pandasai.helpers.filemanager import DefaultFileManager
from pandasai.llm.bamboo_llm import BambooLLM


def test_validate_llm_with_pandabi_api_key():
# Setup
ConfigManager._config = MagicMock()
ConfigManager._config.llm = None

with patch.dict(os.environ, {"PANDABI_API_KEY": "test-key"}):
# Execute
ConfigManager.validate_llm()

# Assert
assert isinstance(ConfigManager._config.llm, BambooLLM)


def test_validate_llm_with_langchain():
# Setup
ConfigManager._config = MagicMock()
mock_llm = MagicMock()
ConfigManager._config.llm = mock_llm
mock_langchain_llm = MagicMock()

# Create mock module
mock_langchain_module = MagicMock()
mock_langchain_module.__spec__ = MagicMock()
mock_langchain_module.langchain = MagicMock(
LangchainLLM=mock_langchain_llm, is_langchain_llm=lambda x: True
)

with patch.dict(
"sys.modules",
{
"pandasai_langchain": mock_langchain_module,
"pandasai_langchain.langchain": mock_langchain_module.langchain,
},
):
# Execute
ConfigManager.validate_llm()

# Assert
assert mock_langchain_llm.call_count == 1


def test_validate_llm_no_action_needed():
# Setup
ConfigManager._config = MagicMock()
mock_llm = MagicMock()
ConfigManager._config.llm = mock_llm

# Case where no PANDABI_API_KEY and not a langchain LLM
with patch.dict(os.environ, {}, clear=True):
with patch("importlib.util.find_spec") as mock_find_spec:
mock_find_spec.return_value = None

# Execute
ConfigManager.validate_llm()

# Assert - llm should remain unchanged
assert ConfigManager._config.llm == mock_llm


def test_config_update():
# Setup
mock_config = MagicMock()
initial_config = {"key1": "value1", "key2": "value2"}
mock_config.model_dump = MagicMock(return_value=initial_config.copy())
ConfigManager._config = mock_config

# Create a mock for Config.from_dict
original_from_dict = Config.from_dict
Config.from_dict = MagicMock()

try:
# Execute
new_config = {"key2": "new_value2", "key3": "value3"}
ConfigManager.update(new_config)

# Assert
expected_config = {"key1": "value1", "key2": "new_value2", "key3": "value3"}
assert mock_config.model_dump.call_count == 1
Config.from_dict.assert_called_once_with(expected_config)
finally:
# Restore original from_dict method
Config.from_dict = original_from_dict


def test_config_set():
# Setup
test_config = {"key": "value"}
with patch.object(Config, "from_dict") as mock_from_dict, patch.object(
ConfigManager, "validate_llm"
) as mock_validate_llm:
# Execute
ConfigManager.set(test_config)

# Assert
mock_from_dict.assert_called_once_with(test_config)
mock_validate_llm.assert_called_once()


def test_config_from_dict():
# Test with default overrides
config_dict = {
"save_logs": False,
"verbose": True,
"enable_cache": False,
"max_retries": 5,
}

config = Config.from_dict(config_dict)

assert isinstance(config, Config)
assert config.save_logs == False
assert config.verbose == True
assert config.enable_cache == False
assert config.max_retries == 5
assert config.llm is None
assert isinstance(config.file_manager, DefaultFileManager)

# Test with minimal dict
minimal_config = Config.from_dict({})
assert isinstance(minimal_config, Config)
assert minimal_config.save_logs == True # default value
assert minimal_config.verbose == False # default value
assert minimal_config.enable_cache == True # default value
assert minimal_config.max_retries == 3 # default value

0 comments on commit b9782e9

Please sign in to comment.