-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tests: add more tests for prompts and smart datalake
- Loading branch information
Showing
5 changed files
with
267 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
from jinja2 import Environment | ||
|
||
from pandasai.core.prompts.base import BasePrompt | ||
|
||
|
||
class TestBasePrompt: | ||
def test_to_json_without_context(self): | ||
# Given a BasePrompt instance without context | ||
class TestPrompt(BasePrompt): | ||
template = "Test template {{ var }}" | ||
|
||
prompt = TestPrompt(var="value") | ||
|
||
# When calling to_json | ||
result = prompt.to_json() | ||
|
||
# Then it should return a dict with only the prompt | ||
assert isinstance(result, dict) | ||
assert list(result.keys()) == ["prompt"] | ||
assert result["prompt"] == "Test template value" | ||
|
||
def test_to_json_with_context(self): | ||
# Given a BasePrompt instance with context | ||
class TestPrompt(BasePrompt): | ||
template = "Test template {{ var }}" | ||
|
||
memory = MagicMock() | ||
memory.to_json.return_value = ["conversation1", "conversation2"] | ||
memory.agent_description = "test agent" | ||
|
||
context = MagicMock() | ||
context.memory = memory | ||
|
||
prompt = TestPrompt(var="value", context=context) | ||
|
||
# When calling to_json | ||
result = prompt.to_json() | ||
|
||
# Then it should return a dict with conversation, system_prompt and prompt | ||
assert isinstance(result, dict) | ||
assert set(result.keys()) == {"conversation", "system_prompt", "prompt"} | ||
assert result["conversation"] == ["conversation1", "conversation2"] | ||
assert result["system_prompt"] == "test agent" | ||
assert result["prompt"] == "Test template value" | ||
|
||
def test_render_with_variables(self): | ||
# Given a BasePrompt instance with a template containing variables | ||
class TestPrompt(BasePrompt): | ||
template = "Hello {{ name }}!\nHow are you?\n\n\n\nGoodbye {{ name }}!" | ||
|
||
prompt = TestPrompt(name="World") | ||
|
||
# When calling render | ||
result = prompt.render() | ||
|
||
# Then it should: | ||
# 1. Replace variables correctly | ||
# 2. Remove extra newlines (more than 2) | ||
expected = "Hello World!\nHow are you?\n\nGoodbye World!" | ||
assert result == expected | ||
|
||
def test_render_with_template_path(self): | ||
# Given a BasePrompt instance with a template path | ||
class TestPrompt(BasePrompt): | ||
template_path = "test_template.txt" | ||
|
||
with patch.object(Environment, "get_template") as mock_get_template: | ||
mock_template = MagicMock() | ||
mock_template.render.return_value = "Hello\n\n\n\nWorld!" | ||
mock_get_template.return_value = mock_template | ||
|
||
prompt = TestPrompt(name="Test") | ||
|
||
# When calling render | ||
result = prompt.render() | ||
|
||
# Then it should: | ||
# 1. Use the template from file | ||
# 2. Remove extra newlines | ||
assert result == "Hello\n\nWorld!" | ||
mock_template.render.assert_called_once_with(name="Test") |
51 changes: 51 additions & 0 deletions
51
tests/unit_tests/core/prompts/test_correct_execute_sql_query_usage_error_prompt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from unittest.mock import Mock, patch | ||
|
||
import pytest | ||
|
||
from pandasai.core.prompts.correct_execute_sql_query_usage_error_prompt import ( | ||
CorrectExecuteSQLQueryUsageErrorPrompt, | ||
) | ||
|
||
|
||
def test_to_json(): | ||
# Mock the dependencies | ||
mock_dataset = Mock() | ||
mock_dataset.to_json.return_value = {"mock_dataset": "data"} | ||
|
||
mock_memory = Mock() | ||
mock_memory.to_json.return_value = {"mock_conversation": "data"} | ||
mock_memory.agent_description = "Mock agent description" | ||
|
||
mock_context = Mock() | ||
mock_context.memory = mock_memory | ||
mock_context.dfs = [mock_dataset] | ||
|
||
# Create test data | ||
test_code = "SELECT * FROM table" | ||
test_error = Exception("Test error") | ||
|
||
# Create instance of the prompt class | ||
prompt = CorrectExecuteSQLQueryUsageErrorPrompt( | ||
context=mock_context, | ||
code=test_code, | ||
error=test_error, | ||
) | ||
|
||
# Call the method | ||
result = prompt.to_json() | ||
|
||
# Assertions | ||
assert result == { | ||
"datasets": [{"mock_dataset": "data"}], | ||
"conversation": {"mock_conversation": "data"}, | ||
"system_prompt": "Mock agent description", | ||
"error": { | ||
"code": test_code, | ||
"error_trace": str(test_error), | ||
"exception_type": "ExecuteSQLQueryNotUsed", | ||
}, | ||
} | ||
|
||
# Verify the mocks were called | ||
mock_dataset.to_json.assert_called_once() | ||
mock_memory.to_json.assert_called_once() |
58 changes: 58 additions & 0 deletions
58
tests/unit_tests/core/prompts/test_correct_output_type_error_prompt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from unittest.mock import Mock, patch | ||
|
||
import pytest | ||
|
||
from pandasai.core.prompts.correct_output_type_error_prompt import ( | ||
CorrectOutputTypeErrorPrompt, | ||
) | ||
|
||
|
||
def test_to_json(): | ||
# Mock the necessary dependencies | ||
mock_memory = Mock() | ||
mock_memory.to_json.return_value = {"conversations": "test"} | ||
mock_memory.agent_description = "test agent" | ||
|
||
mock_dataset = Mock() | ||
mock_dataset.to_json.return_value = {"data": "test data"} | ||
|
||
mock_context = Mock() | ||
mock_context.memory = mock_memory | ||
mock_context.dfs = [mock_dataset] | ||
|
||
# Create test data | ||
props = { | ||
"context": mock_context, | ||
"code": "test code", | ||
"error": Exception("test error"), | ||
"output_type": "test_type", | ||
} | ||
|
||
# Create instance of prompt | ||
prompt = CorrectOutputTypeErrorPrompt(**props) | ||
|
||
# Call to_json method | ||
result = prompt.to_json() | ||
|
||
# Verify the structure and content of the result | ||
assert isinstance(result, dict) | ||
assert "datasets" in result | ||
assert "conversation" in result | ||
assert "system_prompt" in result | ||
assert "error" in result | ||
assert "config" in result | ||
|
||
# Verify specific values | ||
assert result["datasets"] == [{"data": "test data"}] | ||
assert result["conversation"] == {"conversations": "test"} | ||
assert result["system_prompt"] == "test agent" | ||
assert result["error"] == { | ||
"code": "test code", | ||
"error_trace": "test error", | ||
"exception_type": "InvalidLLMOutputType", | ||
} | ||
assert result["config"] == {"output_type": "test_type"} | ||
|
||
# Verify that the mock methods were called | ||
mock_memory.to_json.assert_called_once() | ||
mock_dataset.to_json.assert_called_once() |
48 changes: 48 additions & 0 deletions
48
tests/unit_tests/core/prompts/test_generate_python_code_with_sql_prompt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from unittest.mock import Mock, patch | ||
|
||
import pytest | ||
|
||
from pandasai.core.prompts import GeneratePythonCodeWithSQLPrompt | ||
|
||
|
||
@pytest.fixture | ||
def mock_context(): | ||
context = Mock() | ||
context.memory = Mock() | ||
context.memory.to_json.return_value = {"history": []} | ||
context.memory.agent_description = "Test Agent Description" | ||
context.dfs = [Mock()] | ||
context.dfs[0].to_json.return_value = {"name": "test_df", "data": []} | ||
context.config.direct_sql = True | ||
return context | ||
|
||
|
||
def test_to_json(mock_context): | ||
"""Test that to_json returns the expected structure with all required fields""" | ||
prompt = GeneratePythonCodeWithSQLPrompt(context=mock_context, output_type="code") | ||
|
||
# Mock the to_string method | ||
with patch.object(prompt, "to_string", return_value="test prompt"): | ||
result = prompt.to_json() | ||
|
||
assert isinstance(result, dict) | ||
assert "datasets" in result | ||
assert isinstance(result["datasets"], list) | ||
assert len(result["datasets"]) == 1 | ||
assert result["datasets"][0] == {"name": "test_df", "data": []} | ||
|
||
assert "conversation" in result | ||
assert result["conversation"] == {"history": []} | ||
|
||
assert "system_prompt" in result | ||
assert result["system_prompt"] == "Test Agent Description" | ||
|
||
assert "prompt" in result | ||
assert result["prompt"] == "test prompt" | ||
|
||
assert "config" in result | ||
assert isinstance(result["config"], dict) | ||
assert "direct_sql" in result["config"] | ||
assert result["config"]["direct_sql"] is True | ||
assert "output_type" in result["config"] | ||
assert result["config"]["output_type"] == "code" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters