Skip to content

Commit

Permalink
tests: add more tests for prompts and smart datalake
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Feb 3, 2025
1 parent fab3458 commit f59cb54
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 0 deletions.
84 changes: 84 additions & 0 deletions tests/unit_tests/core/prompts/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from unittest.mock import MagicMock, patch

import pytest
from jinja2 import Environment

from pandasai.core.prompts.base import BasePrompt


class TestBasePrompt:
    """Unit tests covering ``BasePrompt.to_json`` and ``BasePrompt.render``."""

    def test_to_json_without_context(self):
        """With no ``context`` prop, ``to_json`` holds only the ``prompt`` key."""

        class InlinePrompt(BasePrompt):
            template = "Test template {{ var }}"

        payload = InlinePrompt(var="value").to_json()

        assert isinstance(payload, dict)
        # A single-key dict comparison pins both the key list and the value.
        assert payload == {"prompt": "Test template value"}

    def test_to_json_with_context(self):
        """With a ``context`` prop, memory data is included next to the prompt."""

        class InlinePrompt(BasePrompt):
            template = "Test template {{ var }}"

        fake_memory = MagicMock(agent_description="test agent")
        fake_memory.to_json.return_value = ["conversation1", "conversation2"]
        fake_context = MagicMock(memory=fake_memory)

        payload = InlinePrompt(var="value", context=fake_context).to_json()

        assert isinstance(payload, dict)
        assert set(payload) == {"conversation", "system_prompt", "prompt"}
        assert payload["conversation"] == ["conversation1", "conversation2"]
        assert payload["system_prompt"] == "test agent"
        assert payload["prompt"] == "Test template value"

    def test_render_with_variables(self):
        """``render`` substitutes variables and squeezes long newline runs."""

        class InlinePrompt(BasePrompt):
            template = "Hello {{ name }}!\nHow are you?\n\n\n\nGoodbye {{ name }}!"

        rendered = InlinePrompt(name="World").render()

        # The four consecutive newlines in the template come back as two.
        assert rendered == "Hello World!\nHow are you?\n\nGoodbye World!"

    def test_render_with_template_path(self):
        """``render`` uses the template returned by ``Environment.get_template``."""

        class FilePrompt(BasePrompt):
            template_path = "test_template.txt"

        template_stub = MagicMock()
        template_stub.render.return_value = "Hello\n\n\n\nWorld!"

        # get_template is patched so the test supplies the template object itself.
        with patch.object(Environment, "get_template", return_value=template_stub):
            file_prompt = FilePrompt(name="Test")
            rendered = file_prompt.render()

            assert rendered == "Hello\n\nWorld!"
            template_stub.render.assert_called_once_with(name="Test")
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from unittest.mock import Mock, patch

import pytest

from pandasai.core.prompts.correct_execute_sql_query_usage_error_prompt import (
CorrectExecuteSQLQueryUsageErrorPrompt,
)


def test_to_json():
    """``to_json`` returns datasets, conversation, system prompt and error info."""
    # Stand-ins for the dataset, memory and context the prompt reads from.
    dataset = Mock()
    dataset.to_json.return_value = {"mock_dataset": "data"}

    memory = Mock(agent_description="Mock agent description")
    memory.to_json.return_value = {"mock_conversation": "data"}

    context = Mock(memory=memory, dfs=[dataset])

    sql_code = "SELECT * FROM table"
    raised = Exception("Test error")

    payload = CorrectExecuteSQLQueryUsageErrorPrompt(
        context=context,
        code=sql_code,
        error=raised,
    ).to_json()

    expected_error = {
        "code": sql_code,
        "error_trace": str(raised),
        "exception_type": "ExecuteSQLQueryNotUsed",
    }
    assert payload == {
        "datasets": [{"mock_dataset": "data"}],
        "conversation": {"mock_conversation": "data"},
        "system_prompt": "Mock agent description",
        "error": expected_error,
    }

    # Each collaborator must have been serialised exactly once.
    for collaborator in (dataset, memory):
        collaborator.to_json.assert_called_once()
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from unittest.mock import Mock, patch

import pytest

from pandasai.core.prompts.correct_output_type_error_prompt import (
CorrectOutputTypeErrorPrompt,
)


def test_to_json():
    """``to_json`` exposes datasets, conversation, system prompt, error and config."""
    # Stand-ins for the memory, dataset and context the prompt reads from.
    memory = Mock(agent_description="test agent")
    memory.to_json.return_value = {"conversations": "test"}

    dataset = Mock()
    dataset.to_json.return_value = {"data": "test data"}

    context = Mock(memory=memory, dfs=[dataset])

    prompt = CorrectOutputTypeErrorPrompt(
        context=context,
        code="test code",
        error=Exception("test error"),
        output_type="test_type",
    )
    payload = prompt.to_json()

    # Shape check: required keys must be present (extra keys are tolerated).
    assert isinstance(payload, dict)
    for required_key in (
        "datasets",
        "conversation",
        "system_prompt",
        "error",
        "config",
    ):
        assert required_key in payload

    # Value checks for every required key.
    assert payload["datasets"] == [{"data": "test data"}]
    assert payload["conversation"] == {"conversations": "test"}
    assert payload["system_prompt"] == "test agent"
    assert payload["error"] == {
        "code": "test code",
        "error_trace": "test error",
        "exception_type": "InvalidLLMOutputType",
    }
    assert payload["config"] == {"output_type": "test_type"}

    # Both collaborators must have been serialised exactly once.
    memory.to_json.assert_called_once()
    dataset.to_json.assert_called_once()
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from unittest.mock import Mock, patch

import pytest

from pandasai.core.prompts import GeneratePythonCodeWithSQLPrompt


@pytest.fixture
def mock_context():
    """Provide a mocked context with one dataset, a memory and ``direct_sql`` on."""
    dataset = Mock()
    dataset.to_json.return_value = {"name": "test_df", "data": []}

    memory = Mock(agent_description="Test Agent Description")
    memory.to_json.return_value = {"history": []}

    ctx = Mock(memory=memory, dfs=[dataset])
    ctx.config.direct_sql = True
    return ctx


def test_to_json(mock_context):
    """Check that ``to_json`` carries every expected field and value."""
    sql_prompt = GeneratePythonCodeWithSQLPrompt(
        context=mock_context, output_type="code"
    )

    # Stub to_string so the test controls the value of the "prompt" entry.
    with patch.object(sql_prompt, "to_string", return_value="test prompt"):
        payload = sql_prompt.to_json()

    assert isinstance(payload, dict)

    assert "datasets" in payload
    datasets = payload["datasets"]
    assert isinstance(datasets, list)
    assert datasets == [{"name": "test_df", "data": []}]

    # Top-level scalar entries: the key must be present and hold this value.
    expected_entries = {
        "conversation": {"history": []},
        "system_prompt": "Test Agent Description",
        "prompt": "test prompt",
    }
    for key, expected in expected_entries.items():
        assert key in payload
        assert payload[key] == expected

    assert "config" in payload
    config = payload["config"]
    assert isinstance(config, dict)
    assert "direct_sql" in config
    assert config["direct_sql"] is True
    assert "output_type" in config
    assert config["output_type"] == "code"
26 changes: 26 additions & 0 deletions tests/unit_tests/smart_datalake/test_smart_datalake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from pandasai.config import Config
from pandasai.core.cache import Cache
from pandasai.smart_datalake import SmartDatalake


Expand Down Expand Up @@ -43,3 +44,28 @@ def test_enable_cache(sample_dataframes):
# Test with cache disabled
mock_config.enable_cache = False
assert smart_datalake.enable_cache is False


def test_enable_cache_setter(sample_dataframes):
    """Setting ``enable_cache`` updates the agent config and the ``_cache`` slot."""
    # Mocked agent whose context carries a real Config and no cache yet.
    config = Config(enable_cache=False)
    fake_context = Mock(config=config, cache=None)
    fake_agent = Mock(context=fake_context)

    datalake = SmartDatalake(sample_dataframes)
    datalake._agent = fake_agent  # swap in the mocked agent

    # Turning the cache on flips the config flag and creates a Cache.
    datalake.enable_cache = True
    assert fake_agent.context.config.enable_cache is True
    assert datalake._cache is not None
    assert isinstance(datalake._cache, Cache)

    # Turning it off flips the flag back and drops the Cache.
    datalake.enable_cache = False
    assert fake_agent.context.config.enable_cache is False
    assert datalake._cache is None

0 comments on commit f59cb54

Please sign in to comment.