Skip to content

Commit

Permalink
Getting closer
Browse files Browse the repository at this point in the history
  • Loading branch information
seanchatmangpt committed Jan 31, 2024
1 parent 0532a13 commit 2fe362e
Show file tree
Hide file tree
Showing 10 changed files with 609 additions and 70 deletions.
Empty file added assertion.log
Empty file.
24 changes: 24 additions & 0 deletions src/rdddy/generators/assertion.log
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
2024-01-24 10:35:48,759 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python Assert primitive type for
Assert_python_primitive_string
You will be penalized for not returning only a Assert for Assert_python_primitive_string
2024-01-24 10:36:00,019 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python Assert primitive type for
Assert_python_primitive_string
You will be penalized for not returning only a Assert for Assert_python_primitive_string
2024-01-24 10:38:53,382 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for
set_python_primitive_string
You will be penalized for not returning only a set for set_python_primitive_string
2024-01-24 10:39:14,656 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for
set_python_primitive_string
You will be penalized for not returning only a set for set_python_primitive_string
2024-01-24 10:39:38,514 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for
set_python_primitive_string
You will be penalized for not returning only a set for set_python_primitive_string
2024-01-24 10:44:50,239 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for
set_python_primitive_string
You will be penalized for not returning only a set for set_python_primitive_string
2024-01-24 10:45:49,389 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for
set_python_primitive_string
You will be penalized for not returning only a set for set_python_primitive_string
2024-01-24 10:46:13,339 - dspy.primitives.assertions - ERROR - AssertionError: You need to create a valid python set primitive type for
set_python_primitive_string
You will be penalized for not returning only a set for set_python_primitive_string
50 changes: 50 additions & 0 deletions src/rdddy/generators/gen_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import logging # Import the logging module
from dspy import Module, OpenAI, settings, ChainOfThought, Assert
from pydantic import ValidationError

logger = logging.getLogger(__name__) # Create a logger instance
logger.setLevel(logging.ERROR) # Set the logger's level to ERROR or the appropriate level


class GenModule(Module):
def __init__(self, output_key, input_keys: list[str] = None, lm=None):
if lm is None:
lm = OpenAI(max_tokens=500)
settings.configure(lm=lm)

if input_keys is None:
self.input_keys = ["prompt"]
else:
self.input_keys = input_keys

super().__init__()

self.output_key = output_key

# Define the generation and correction queries based on generation_type
self.signature = ', '.join(self.input_keys) + f" -> {self.output_key}"
self.correction_signature = ', '.join(self.input_keys) + f", error -> {self.output_key}"

# DSPy modules for generation and correction
self.generate = ChainOfThought(self.signature)
self.correct_generate = ChainOfThought(self.correction_signature)

def forward(self, **kwargs):
# Generate the output using provided inputs
gen_result = self.generate(**kwargs)
output = gen_result.get(self.output_key)

# Try validating the output
try:
return self.validate_output(output)
except (AssertionError, ValueError, TypeError) as error:
logger.error(error)
logger.error(output)
# Correction attempt
corrected_result = self.correct_generate(**kwargs, error=str(error))
corrected_output = corrected_result.get(self.output_key)
return self.validate_output(corrected_output)

def validate_output(self, output):
# Implement validation logic or override in subclass
raise NotImplementedError("Validation logic should be implemented in subclass")
101 changes: 101 additions & 0 deletions src/rdddy/generators/gen_pydantic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import dspy
import json
from typing import Dict, Any, Optional

import inspect
from dspy import Module, Assert

from pydantic import BaseModel, Field, ValidationError
from typing import List

from rdddy.generators.gen_module import GenModule
from rdddy.generators.gen_python_primitive import GenDict

def strip_text_before_first_open_brace(input_text):
if "{" in input_text:
return input_text[input_text.index("{"):]
else:
return input_text

class GenPydanticModel(GenModule):
def __init__(self, root_model, models: list = None):
if models is None:
models = [root_model]
elif root_model not in models:
models.append(root_model)

super().__init__(f"{root_model.__name__.lower()}_model_validate_json_dict", input_keys=["inspect_getsource", "prompt"])
self.root_model = root_model
self.models = models
self.model_sources = '\n'.join([inspect.getsource(model) for model in self.models])

def validate_root_model(self, output) -> bool:
try:
return isinstance(self.root_model.model_validate_json(output), self.root_model)
except (ValidationError, TypeError) as error:
return False

def validate_output(self, output):
output = strip_text_before_first_open_brace(str(output))

Assert(
self.validate_root_model(output),
f"""You need to create a dict for {self.root_model.__name__},
You will be penalized for not returning only a {self.root_model.__name__} dict for {self.output_key}""",
)

return self.root_model.model_validate_json(output)

def forward(self, **kwargs):
# spec = dspy.ChainOfThought("prompt, source -> instance")

# result = spec.forward(prompt=f'{kwargs["prompt"]}\nalign the prompt with the source', source=self.model_sources).instance

# return super().forward(inspect_getsource=self.model_sources, prompt=result)

return super().forward(inspect_getsource=self.model_sources, prompt=kwargs["prompt"])

# Create a detailed instruction for prompt refinement
# refinement_instruction = (
# "Below are the Pydantic model definitions:\n{}\n\n"
# "Based on these models, restructure the following description to align with the models:\n{}\n"
# "Restructured Description:"
# ).format(self.model_sources, kwargs["prompt"])
#
# # Use ChainOfThought for prompt refinement
# refined_prompt_result = dspy.ChainOfThought("prompt -> refined_prompt")
# refined_prompt = refined_prompt_result.forward(prompt=refinement_instruction).get("refined_prompt")
#
# # Proceed with the refined prompt
# return super().forward(inspect_getsource=self.model_sources, prompt=refined_prompt)


api_description = """
Service: Get Current Weather
Action: Retrieve
Path: /weather/current
Description: Fetches the latest weather information.
Parameters: { "location": "Specify a location as text" }
Output: Provides a JSON-based weather report with all the details.
APIEndpoint.model_validate_json(your_output)
"""

class APIEndpoint(BaseModel):
method: str = Field(..., description="HTTP method of the API endpoint")
url: str = Field(..., description="URL of the API endpoint")
description: str = Field(..., description="Description of what the API endpoint does")
response: str = Field(..., description="Response from the API endpoint")
query_params: Optional[Dict[str, Any]] = Field(None, description="Query parameters")


def main():
dot = GenPydanticModel(root_model=APIEndpoint)
result = dot.forward(prompt=api_description)
print(result)


if __name__ == '__main__':
main()

136 changes: 72 additions & 64 deletions src/rdddy/generators/gen_python_primitive.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,95 @@
import ast

from dspy import Module, OpenAI, settings, ChainOfThought, Assert
from dspy import Assert

from rdddy.generators.gen_module import GenModule

class GenPythonPrimitive(Module):
def __init__(self, primitive_type, lm=None):
if lm is None:
turbo = OpenAI(max_tokens=500)

settings.configure(lm=turbo)
def is_primitive_type(data_type):
primitive_types = {
int, float, str, bool, list, tuple, dict, set
}

super().__init__()
return data_type in primitive_types

if primitive_type is set:
raise ValueError("Set not supported.")

self.prompt = None
class GenPythonPrimitive(GenModule):
def __init__(self, primitive_type, lm=None):
if not is_primitive_type(primitive_type):
raise ValueError(f'primitive type {primitive_type.__name__} must be a Python primitive type')
super().__init__(f"{primitive_type.__name__}_python_primitive_pep8_string", lm)
self.primitive_type = primitive_type
self.output_key = f"{primitive_type.__name__}_python_primitive_string"
generation_query = f"prompt -> {self.output_key}"
correction_query = f"prompt, error -> {self.output_key}"

# DSPy modules for generation and correction
self.cot = ChainOfThought(generation_query)
self.correct_cot = ChainOfThought(correction_query)

def forward(self, prompt: str):
self.prompt = prompt
# Generate the primitive
cot_result = self.cot(prompt=prompt)
output = cot_result.get(self.output_key)

# Try validating the primitive
try:
if self.primitive_type is str:
return output
if self.primitive_type is bool and "false" in output.lower():
return False
if self.primitive_type is bool and "true" in output.lower():
return True

Assert(
self.validate_primitive(output),
f"You need to create a valid python {self.primitive_type.__name__} "
f"primitive type for \n{self.output_key}\n"
f"You will be penalized for not returning only a {self.primitive_type.__name__} for "
f"{self.output_key}",
)

return ast.literal_eval(output)
except (SyntaxError, AssertionError, ValueError) as error:
print(error)
# Try again
try:
cot_result = self.correct_cot(prompt=prompt, error=str(error))
output = cot_result.get(self.output_key)

return ast.literal_eval(output)
except (SyntaxError, ValueError) as error:
raise ValueError(
f"Unable to correctly generate a python "
f"{self.primitive_type.__name__} from {self.prompt}. "
)

def validate_primitive(self, output) -> bool:
try:
return isinstance(ast.literal_eval(output), self.primitive_type)
except SyntaxError as error:
return False

def validate_output(self, output):
Assert(
self.validate_primitive(output),
f"You need to create a valid python {self.primitive_type.__name__} "
f"primitive type for \n{self.output_key}\n"
f"You will be penalized for not returning only a {self.primitive_type.__name__} for "
f"{self.output_key}",
)
data = ast.literal_eval(output)

if self.primitive_type is set:
data = set(data)
return data

def __call__(self, prompt):
return self.forward(prompt=prompt)


class GenDict(GenPythonPrimitive):
def __init__(self):
super().__init__(primitive_type=dict)


class GenList(GenPythonPrimitive):
def __init__(self):
super().__init__(primitive_type=list)

def main():
module = GenPythonPrimitive(
primitive_type=list,
)

result = module.forward(
"Create a list of planets in our solar system sorted by largest to smallest"
)
class GenBool(GenPythonPrimitive):
def __init__(self):
super().__init__(primitive_type=bool)


class GenInt(GenPythonPrimitive):
def __init__(self):
super().__init__(primitive_type=int)


class GenFloat(GenPythonPrimitive):
def __init__(self):
super().__init__(primitive_type=float)


class GenTuple(GenPythonPrimitive):
def __init__(self):
super().__init__(primitive_type=tuple)


class GenSet(GenPythonPrimitive):
def __init__(self):
super().__init__(primitive_type=set)


class GenStr(GenPythonPrimitive):
def __init__(self):
super().__init__(primitive_type=str)


def main():
result = GenTuple()("Create a list of planets in our solar system sorted by largest to smallest")

assert result == ['Jupiter', 'Saturn', 'Uranus', 'Neptune', 'Earth', 'Venus', 'Mars', 'Mercury']
assert result == ('Jupiter', 'Saturn', 'Uranus', 'Neptune', 'Earth', 'Venus', 'Mars', 'Mercury')

print(f"The number of planets in the solar system is {result}")
print(f"The planets of the solar system are {result}")


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 2fe362e

Please sign in to comment.