Skip to content

Commit

Permalink
tests: add tests for config, smart dataframe and smart datalake
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Feb 3, 2025
1 parent 936c6f0 commit fab3458
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pandasai/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from abc import ABC, abstractmethod
from importlib.util import find_spec
from typing import Any, Dict, Optional

Expand Down Expand Up @@ -38,6 +37,9 @@ def set(cls, config_dict: Dict[str, Any]) -> None:
@classmethod
def get(cls) -> Config:
"""Get the global configuration."""
if cls._config is None:
cls._config = Config()

if cls._config.llm is None and os.environ.get("PANDABI_API_KEY"):
from pandasai.llm.bamboo_llm import BambooLLM

Expand Down
123 changes: 123 additions & 0 deletions tests/unit_tests/smart_dataframe/test_smart_dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import warnings

import pandas as pd
import pytest

from pandasai.config import Config
from pandasai.llm.fake import FakeLLM
from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes


def test_smart_dataframe_init_basic():
# Create a sample dataframe
df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})

# Test initialization with minimal parameters
with pytest.warns(DeprecationWarning):
smart_df = SmartDataframe(df)

assert smart_df._original_import is df
assert isinstance(smart_df.dataframe, pd.DataFrame)
assert smart_df._table_name is None
assert smart_df._table_description is None
assert smart_df._custom_head is None


def test_smart_dataframe_init_with_all_params():
# Create sample dataframes
df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})
custom_head = pd.DataFrame({"A": [1], "B": ["x"]})
config = Config(llm=FakeLLM())

# Test initialization with all parameters
with pytest.warns(DeprecationWarning):
smart_df = SmartDataframe(
df,
name="test_df",
description="Test dataframe",
custom_head=custom_head,
config=config,
)

assert smart_df._original_import is df
assert isinstance(smart_df.dataframe, pd.DataFrame)
assert smart_df._table_name == "test_df"
assert smart_df._table_description == "Test dataframe"
assert smart_df._custom_head == custom_head.to_csv(index=False)
assert smart_df._agent._state._config == config


def test_smart_dataframe_deprecation_warning():
df = pd.DataFrame({"A": [1, 2, 3]})

with warnings.catch_warnings(record=True) as warning_info:
warnings.simplefilter("always")
SmartDataframe(df)

deprecation_warnings = [
w for w in warning_info if issubclass(w.category, DeprecationWarning)
]
assert len(deprecation_warnings) >= 1
assert "SmartDataframe will soon be deprecated" in str(
deprecation_warnings[0].message
)


def test_load_df_success():
# Create sample dataframes
original_df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})
with pytest.warns(DeprecationWarning):
smart_df = SmartDataframe(original_df)

# Test loading a new dataframe
new_df = pd.DataFrame({"C": [4, 5, 6], "D": ["a", "b", "c"]})
loaded_df = smart_df.load_df(
new_df,
name="new_df",
description="New test dataframe",
custom_head=pd.DataFrame({"C": [4], "D": ["a"]}),
)

assert isinstance(loaded_df, pd.DataFrame)
assert loaded_df.equals(new_df)


def test_load_df_invalid_input():
# Create a sample dataframe
original_df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})
with pytest.warns(DeprecationWarning):
smart_df = SmartDataframe(original_df)

# Test loading invalid data
with pytest.raises(
ValueError, match="Invalid input data. We cannot convert it to a dataframe."
):
smart_df.load_df(
"not a dataframe",
name="invalid_df",
description="Invalid test data",
custom_head=None,
)


def test_load_smartdataframes():
# Create sample dataframes
df1 = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})
df2 = pd.DataFrame({"C": [4, 5, 6], "D": ["a", "b", "c"]})

# Create a config with FakeLLM
config = Config(llm=FakeLLM())

# Test loading regular pandas DataFrames
smart_dfs = load_smartdataframes([df1, df2], config)
assert len(smart_dfs) == 2
assert all(isinstance(df, SmartDataframe) for df in smart_dfs)
assert all(hasattr(df, "config") for df in smart_dfs)

# Test loading mixed pandas DataFrames and SmartDataframes
existing_smart_df = SmartDataframe(df1, config=config)
mixed_dfs = load_smartdataframes([existing_smart_df, df2], config)
assert len(mixed_dfs) == 2
assert mixed_dfs[0] is existing_smart_df # Should return the same instance
assert isinstance(mixed_dfs[1], SmartDataframe)
assert hasattr(mixed_dfs[1], "config")
45 changes: 45 additions & 0 deletions tests/unit_tests/smart_datalake/test_smart_datalake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from unittest.mock import Mock

import pandas as pd
import pytest

from pandasai.config import Config
from pandasai.smart_datalake import SmartDatalake


@pytest.fixture
def sample_dataframes():
df1 = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
df2 = pd.DataFrame({"C": [7, 8, 9], "D": [10, 11, 12]})
return [df1, df2]


def test_dfs_property(sample_dataframes):
# Create a mock agent with context
mock_agent = Mock()
mock_agent.context.dfs = sample_dataframes

# Create SmartDatalake instance
smart_datalake = SmartDatalake(sample_dataframes)
smart_datalake._agent = mock_agent # Inject mock agent

# Test that dfs property returns the correct dataframes
assert smart_datalake.dfs == sample_dataframes


def test_enable_cache(sample_dataframes):
# Create a mock agent with context and config
mock_config = Config(enable_cache=True)
mock_agent = Mock()
mock_agent.context.config = mock_config

# Create SmartDatalake instance
smart_datalake = SmartDatalake(sample_dataframes)
smart_datalake._agent = mock_agent # Inject mock agent

# Test that enable_cache property returns the correct value
assert smart_datalake.enable_cache is True

# Test with cache disabled
mock_config.enable_cache = False
assert smart_datalake.enable_cache is False
42 changes: 42 additions & 0 deletions tests/unit_tests/test_api_key_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
from unittest.mock import patch

import pytest

from pandasai.config import APIKeyManager


def test_set_api_key():
# Setup
test_api_key = "test-api-key-123"

# Execute
with patch.dict(os.environ, {}, clear=True):
APIKeyManager.set(test_api_key)

# Assert
assert os.environ.get("PANDABI_API_KEY") == test_api_key
assert APIKeyManager._api_key == test_api_key


def test_get_api_key():
# Setup
test_api_key = "test-api-key-123"
APIKeyManager._api_key = test_api_key

# Execute
result = APIKeyManager.get()

# Assert
assert result == test_api_key


def test_get_api_key_when_none():
# Setup
APIKeyManager._api_key = None

# Execute
result = APIKeyManager.get()

# Assert
assert result is None
86 changes: 86 additions & 0 deletions tests/unit_tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
from importlib.util import find_spec
from unittest.mock import MagicMock, patch

import pytest

from pandasai.config import APIKeyManager, Config, ConfigManager
from pandasai.llm.bamboo_llm import BambooLLM


class TestConfigManager:
def setup_method(self):
# Reset the ConfigManager state before each test
ConfigManager._config = None
ConfigManager._initialized = False

def test_validate_llm_with_pandabi_api_key(self):
"""Test validate_llm when PANDABI_API_KEY is set"""
with patch.dict(os.environ, {"PANDABI_API_KEY": "test-key"}):
ConfigManager._config = MagicMock()
ConfigManager._config.llm = None

ConfigManager.validate_llm()

assert isinstance(ConfigManager._config.llm, BambooLLM)

def test_validate_llm_without_pandabi_api_key(self):
"""Test validate_llm when PANDABI_API_KEY is not set"""
with patch.dict(os.environ, {}, clear=True):
ConfigManager._config = MagicMock()
ConfigManager._config.llm = None

ConfigManager.validate_llm()

assert ConfigManager._config.llm is None

@pytest.mark.skipif(
find_spec("pandasai_langchain") is None,
reason="pandasai_langchain not installed",
)
def test_validate_llm_with_langchain(self):
"""Test validate_llm with langchain integration"""
from pandasai_langchain.langchain import LangchainLLM

mock_langchain_llm = MagicMock()
ConfigManager._config = MagicMock()
ConfigManager._config.llm = mock_langchain_llm

with patch("pandasai_langchain.langchain.is_langchain_llm", return_value=True):
ConfigManager.validate_llm()

assert isinstance(ConfigManager._config.llm, LangchainLLM)
assert ConfigManager._config.llm._llm == mock_langchain_llm

def test_update_config(self):
"""Test updating configuration with new values"""
# Initialize config with some initial values
initial_config = {"save_logs": True, "verbose": False}
ConfigManager._config = Config.from_dict(initial_config)

# Update with new values
update_dict = {"verbose": True, "enable_cache": False}
ConfigManager.update(update_dict)

# Verify the configuration was updated correctly
updated_config = ConfigManager._config.model_dump()
assert updated_config["save_logs"] is True # Original value preserved
assert updated_config["verbose"] is True # Value updated
assert updated_config["enable_cache"] is False # New value added

def test_set_api_key(self):
"""Test setting the API key"""
test_api_key = "test-api-key-123"

# Clear any existing API key
if "PANDABI_API_KEY" in os.environ:
del os.environ["PANDABI_API_KEY"]
APIKeyManager._api_key = None

# Set the API key
APIKeyManager.set(test_api_key)

# Verify the API key is set in both places
assert os.environ["PANDABI_API_KEY"] == test_api_key
assert APIKeyManager._api_key == test_api_key
assert APIKeyManager.get() == test_api_key # Also test the get method

0 comments on commit fab3458

Please sign in to comment.