-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0532a13
commit 2fe362e
Showing
10 changed files
with
609 additions
and
70 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
2024-01-24 10:35:48,759 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python Assert primitive type for | ||
Assert_python_primitive_string | ||
You will be penalized for not returning only a Assert for Assert_python_primitive_string | ||
2024-01-24 10:36:00,019 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python Assert primitive type for | ||
Assert_python_primitive_string | ||
You will be penalized for not returning only a Assert for Assert_python_primitive_string | ||
2024-01-24 10:38:53,382 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for | ||
set_python_primitive_string | ||
You will be penalized for not returning only a set for set_python_primitive_string | ||
2024-01-24 10:39:14,656 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for | ||
set_python_primitive_string | ||
You will be penalized for not returning only a set for set_python_primitive_string | ||
2024-01-24 10:39:38,514 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for | ||
set_python_primitive_string | ||
You will be penalized for not returning only a set for set_python_primitive_string | ||
2024-01-24 10:44:50,239 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for | ||
set_python_primitive_string | ||
You will be penalized for not returning only a set for set_python_primitive_string | ||
2024-01-24 10:45:49,389 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for | ||
set_python_primitive_string | ||
You will be penalized for not returning only a set for set_python_primitive_string | ||
2024-01-24 10:46:13,339 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for | ||
set_python_primitive_string | ||
You will be penalized for not returning only a set for set_python_primitive_string |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import logging # Import the logging module | ||
from dspy import Module, OpenAI, settings, ChainOfThought, Assert | ||
from pydantic import ValidationError | ||
|
||
logger = logging.getLogger(__name__) # Create a logger instance | ||
logger.setLevel(logging.ERROR) # Set the logger's level to ERROR or the appropriate level | ||
|
||
|
||
class GenModule(Module): | ||
def __init__(self, output_key, input_keys: list[str] = None, lm=None): | ||
if lm is None: | ||
lm = OpenAI(max_tokens=500) | ||
settings.configure(lm=lm) | ||
|
||
if input_keys is None: | ||
self.input_keys = ["prompt"] | ||
else: | ||
self.input_keys = input_keys | ||
|
||
super().__init__() | ||
|
||
self.output_key = output_key | ||
|
||
# Define the generation and correction queries based on generation_type | ||
self.signature = ', '.join(self.input_keys) + f" -> {self.output_key}" | ||
self.correction_signature = ', '.join(self.input_keys) + f", error -> {self.output_key}" | ||
|
||
# DSPy modules for generation and correction | ||
self.generate = ChainOfThought(self.signature) | ||
self.correct_generate = ChainOfThought(self.correction_signature) | ||
|
||
def forward(self, **kwargs): | ||
# Generate the output using provided inputs | ||
gen_result = self.generate(**kwargs) | ||
output = gen_result.get(self.output_key) | ||
|
||
# Try validating the output | ||
try: | ||
return self.validate_output(output) | ||
except (AssertionError, ValueError, TypeError) as error: | ||
logger.error(error) | ||
logger.error(output) | ||
# Correction attempt | ||
corrected_result = self.correct_generate(**kwargs, error=str(error)) | ||
corrected_output = corrected_result.get(self.output_key) | ||
return self.validate_output(corrected_output) | ||
|
||
def validate_output(self, output): | ||
# Implement validation logic or override in subclass | ||
raise NotImplementedError("Validation logic should be implemented in subclass") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import dspy | ||
import json | ||
from typing import Dict, Any, Optional | ||
|
||
import inspect | ||
from dspy import Module, Assert | ||
|
||
from pydantic import BaseModel, Field, ValidationError | ||
from typing import List | ||
|
||
from rdddy.generators.gen_module import GenModule | ||
from rdddy.generators.gen_python_primitive import GenDict | ||
|
||
def strip_text_before_first_open_brace(input_text): | ||
if "{" in input_text: | ||
return input_text[input_text.index("{"):] | ||
else: | ||
return input_text | ||
|
||
class GenPydanticModel(GenModule): | ||
def __init__(self, root_model, models: list = None): | ||
if models is None: | ||
models = [root_model] | ||
elif root_model not in models: | ||
models.append(root_model) | ||
|
||
super().__init__(f"{root_model.__name__.lower()}_model_validate_json_dict", input_keys=["inspect_getsource", "prompt"]) | ||
self.root_model = root_model | ||
self.models = models | ||
self.model_sources = '\n'.join([inspect.getsource(model) for model in self.models]) | ||
|
||
def validate_root_model(self, output) -> bool: | ||
try: | ||
return isinstance(self.root_model.model_validate_json(output), self.root_model) | ||
except (ValidationError, TypeError) as error: | ||
return False | ||
|
||
def validate_output(self, output): | ||
output = strip_text_before_first_open_brace(str(output)) | ||
|
||
Assert( | ||
self.validate_root_model(output), | ||
f"""You need to create a dict for {self.root_model.__name__}, | ||
You will be penalized for not returning only a {self.root_model.__name__} dict for {self.output_key}""", | ||
) | ||
|
||
return self.root_model.model_validate_json(output) | ||
|
||
def forward(self, **kwargs): | ||
# spec = dspy.ChainOfThought("prompt, source -> instance") | ||
|
||
# result = spec.forward(prompt=f'{kwargs["prompt"]}\nalign the prompt with the source', source=self.model_sources).instance | ||
|
||
# return super().forward(inspect_getsource=self.model_sources, prompt=result) | ||
|
||
return super().forward(inspect_getsource=self.model_sources, prompt=kwargs["prompt"]) | ||
|
||
# Create a detailed instruction for prompt refinement | ||
# refinement_instruction = ( | ||
# "Below are the Pydantic model definitions:\n{}\n\n" | ||
# "Based on these models, restructure the following description to align with the models:\n{}\n" | ||
# "Restructured Description:" | ||
# ).format(self.model_sources, kwargs["prompt"]) | ||
# | ||
# # Use ChainOfThought for prompt refinement | ||
# refined_prompt_result = dspy.ChainOfThought("prompt -> refined_prompt") | ||
# refined_prompt = refined_prompt_result.forward(prompt=refinement_instruction).get("refined_prompt") | ||
# | ||
# # Proceed with the refined prompt | ||
# return super().forward(inspect_getsource=self.model_sources, prompt=refined_prompt) | ||
|
||
|
||
api_description = """ | ||
Service: Get Current Weather | ||
Action: Retrieve | ||
Path: /weather/current | ||
Description: Fetches the latest weather information. | ||
Parameters: { "location": "Specify a location as text" } | ||
Output: Provides a JSON-based weather report with all the details. | ||
APIEndpoint.model_validate_json(your_output) | ||
""" | ||
|
||
class APIEndpoint(BaseModel): | ||
method: str = Field(..., description="HTTP method of the API endpoint") | ||
url: str = Field(..., description="URL of the API endpoint") | ||
description: str = Field(..., description="Description of what the API endpoint does") | ||
response: str = Field(..., description="Response from the API endpoint") | ||
query_params: Optional[Dict[str, Any]] = Field(None, description="Query parameters") | ||
|
||
|
||
def main(): | ||
dot = GenPydanticModel(root_model=APIEndpoint) | ||
result = dot.forward(prompt=api_description) | ||
print(result) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.