
Conversation


@mathematicalmichael commented Aug 31, 2025

This PR adds a class method from_model to create a LitTool from a Pydantic model, including setup and run methods for validation.

In full transparency, I don't know that this is the "best" way to accomplish the task at hand, but consider it one proposal.

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not needed for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Adds a convenience method that converts existing classes inheriting from pydantic.BaseModel into LitTool classes.

Why is this useful?

A lot of use cases stop short of true tool-calling and really just require schema enforcement (e.g., named entity extraction). This method helps create "tools" which do not actually have a function associated with them.

Example:

import os
import json

from litai import LLM, LitTool
from pydantic import BaseModel
from typing import Literal

class RelationshipNode(BaseModel):
    source_entity: str
    target_entity: str
    relation: Literal["consumer", "producer", "partner"]

get_relationship = LitTool.from_model(RelationshipNode)

llm = LLM(
    model="google/gemini-2.5-flash",
    api_key=os.environ.get("LITAI_API_KEY"),
    max_retries=1,
)

response = llm.chat("Michael purchased credits from Lightning AI", tools=[get_relationship], system_prompt=None)
print(response)

# validate that tool can be called
if "function" not in response:
    raise AssertionError("No function call found in response")
# if available, proceed to check ability to call tool (not necessary in practice, just demonstrates compatibility)
result = llm.call_tool(response, tools=[get_relationship])
print(result)
Shell output:

$ python ../from_model.py
[{"function": {"arguments": "{\"target_entity\":\"Lightning AI\",\"source_entity\":\"Michael\",\"relation\":\"consumer\"}", "name": "RelationshipNode"}}]
source_entity='Michael' target_entity='Lightning AI' relation='consumer'

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

yes!

Additional Information

what feels wrong to me ergonomics-wise with the example above is this use-case being served by the .chat interface instead of something dedicated for this purpose.

open to suggestions. in theory, the goal is to just go FROM text TO json that adheres to the pydantic model.

That question is distinct from the contribution in the PR, though.

Added a class method 'from_model' to create a LitTool from a Pydantic model, including setup and run methods for validation.
Comment on lines 129 to 131
def run(self, *args, **kwargs) -> Any:
    # Default implementation: validate & return an instance
    return model(*args, **kwargs)


Curious—how is this meant to be invoked? In the run method, it looks like it would just return a Pydantic model instance, right?


check out the shell output in the description. when invoked, it does just return the model. if I didn't implement this, calling tools would break, even though it isn't helpful to invoke. it's kind of just pass-through behavior.
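To make that pass-through behavior concrete, here is a minimal self-contained sketch. LitToolSketch and its from_model are illustrative stand-ins, not the actual litai implementation; the real PR may differ in details:

```python
from typing import Any, Type


class LitToolSketch:
    """Stand-in for the LitTool surface used in this discussion."""

    name: str = ""

    def run(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    @classmethod
    def from_model(cls, model: Type) -> "LitToolSketch":
        class _ModelTool(cls):
            def __init__(self) -> None:
                # Tool name defaults to the model's class name, matching the
                # "RelationshipNode" name seen in the shell output above.
                self.name = model.__name__

            def run(self, *args: Any, **kwargs: Any) -> Any:
                # Pass-through: constructing the model validates the
                # arguments and returns an instance.
                return model(*args, **kwargs)

        return _ModelTool()
```

With a pydantic model, run(**tool_args) would validate the arguments and hand back the model instance, which is exactly the "not helpful to invoke, but must not break" behavior described.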

@Borda changed the title LitTool.from_model method to create LitTool from Pydantic Sep 1, 2025

codecov bot commented Sep 1, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 85%. Comparing base (5c6f65c) to head (25d9551).

Additional details and impacted files
@@        Coverage Diff         @@
##           main   #57   +/-   ##
==================================
  Coverage    84%   85%           
==================================
  Files         8     8           
  Lines       431   443   +12     
==================================
+ Hits        364   376   +12     
  Misses       67    67           

@Borda requested a review from bhimrazy September 3, 2025 06:05

mathematicalmichael commented Sep 6, 2025

this almost feels like it could be part of llm.classify


mathematicalmichael commented Sep 6, 2025

upon some further testing, I came across an issue when trying to use llm.call_tool with the tools created from pydantic models.

the call_tool method declares str as its return type, meaning I have two options to make this PR's implementation compatible:

(1) modify the run method of from_model:

            def run(self, *args, **kwargs) -> Any:  # type: ignore
                # Default implementation: validate & return an instance
                return model(*args, **kwargs).model_dump()  # <-- change here to make it json-serializable

(2) modify call_tool's return contract from Optional[str] to Optional[Union[str, BaseModel, list[BaseModel]]]

    @staticmethod
    def call_tool(
        response: Union[List[dict], dict, str], tools: Optional[Sequence[Union[LitTool, "StructuredTool"]]] = None
    ) -> Optional[Union[str, BaseModel, list[BaseModel]]]:
        ...
        try:
            return json.dumps(results) if len(results) > 1 else results[0]
        except TypeError:
            return results if len(results) > 1 else results[0]

my preference is (2) so that the invocation of call_tool actually returns the pydantic models (helpful for downstream data ingestion)
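For context on why the try/except in option (2) matters: json.dumps raises TypeError on objects it cannot serialize (such as pydantic model instances), so the fallback hands back the raw objects instead. A standalone sketch of just that ending (Opaque and serialize_results are illustrative names, not code from the PR):

```python
import json
from typing import Any, List


class Opaque:
    """Stands in for a pydantic model instance, which json.dumps cannot handle."""


def serialize_results(results: List[Any]) -> Any:
    # Mirrors the proposed call_tool ending: JSON string when all results
    # are serializable, raw object(s) otherwise.
    try:
        return json.dumps(results) if len(results) > 1 else results[0]
    except TypeError:
        return results if len(results) > 1 else results[0]
```

A single result is returned as-is either way; only multi-result lists hit json.dumps, and non-serializable ones fall through to the raw list.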

however, option (3) is also available: move the logic to a dedicated method such as classify (e.g. predict), which creates a bit of repeated code (doesn't bother me).

for example:

    def predict(  # noqa: D417
        self,
        prompt: str,
        contracts: Sequence[type[BaseModel]],
        system_prompt: Optional[str] = None,
        model: Optional[str] = None,
        max_tokens: int = 500,
        images: Optional[Union[List[str], str]] = None,
        conversation: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        stream: bool = False,
        auto_call_tools: bool = False,
        **kwargs: Any,
    ) -> Optional[Union[BaseModel, list[BaseModel]]]:
        """Sends a message to the LLM and retrieves a structured response based on the provided Pydantic models."""
        tools = [LitTool.from_model(c) for c in contracts]
        response = self.chat(
            prompt=prompt,
            system_prompt=system_prompt,
            model=model,
            max_tokens=max_tokens,
            images=images,
            conversation=conversation,
            metadata=metadata,
            stream=stream,
            tools=tools,
            auto_call_tools=auto_call_tools,
            **kwargs,
        )
        # Call tool(s) with the given response.
        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                raise ValueError("Tool response is not a valid JSON string")

        results = []
        if isinstance(response, dict):
            response = [response]

        for tool_response in response:
            if not isinstance(tool_response, dict):
                continue
            tool_name = tool_response.get("function", {}).get("name")
            if not tool_name:
                continue
            tool_args = tool_response.get("function", {}).get("arguments", {})
            if isinstance(tool_args, str):
                try:
                    tool_args = json.loads(tool_args)
                except json.JSONDecodeError:
                    print(f"❌ Failed to parse tool arguments: {tool_args}")
                    return None
            if isinstance(tool_args, dict):
                tool_args = {k: v for k, v in tool_args.items() if v is not None}

            for tool in tools:
                if tool.name == tool_name:
                    results.append(tool.run(**tool_args))

        if len(results) == 0:
            return None

        return results if len(results) > 1 else results[0]

the upside of this is a dedicated method, and the user avoids needing to call LitTool.from_model explicitly.
(though I still think I'd like the try/except in call_tool for compatibility)

let me know which path is suitable and I'll push up another commit. @bhimrazy


bhimrazy commented Sep 8, 2025

Hi @mathematicalmichael, thanks for the updates.

I’m a bit unsure about the purpose here — this feels more like structured data extraction than a tool implementation.

Let’s hear what the maintainers think, and you can proceed accordingly.
cc: @k223kim @aniketmaurya

From my perspective, this type of task is usually handled via a response_format parameter or by guiding the model with a system prompt.
Probably something like an llm.extract (or a dedicated API) would be a more natural fit.

@mathematicalmichael

that is correct @bhimrazy

structured extraction is the goal; tool use is almost identical under the hood, though.

semantics aside (what to call the method), I did want to put the functionality forward (it's 95% of the business use cases I encounter).
I do think predict as a method name makes some sense.

Danidapena (Collaborator) commented Sep 9, 2025

@mathematicalmichael I agree with you—option 2 feels like the best way forward. Option 3 has some interesting points, but it might be a bit harder to maintain.

@mathematicalmichael

re (3): I've been putting option (3) through its paces (hundreds of API calls via llm.predict) on a project (pointing to my predict branch).
In doing so, I found myself calling result.model_dump() on the output pydantic object for actual downstream use anyway, meaning that the approach in (2) plus a string parsing function would probably work out better than a dedicated llm.predict, and, yes, be simpler to maintain.
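Since downstream use ends up calling model_dump() anyway, a small helper on top of option (2)'s return contract could do that normalization in one place. A sketch (to_records and its duck-typed model_dump check are illustrative, not part of the PR):

```python
import json
from typing import Any, List


def to_records(result: Any) -> List[dict]:
    # Accepts the Optional[Union[str, BaseModel, list[BaseModel]]] shapes
    # proposed in option (2) and returns plain dicts for data ingestion.
    if result is None:
        return []
    if isinstance(result, str):
        parsed = json.loads(result)
        return parsed if isinstance(parsed, list) else [parsed]
    items = result if isinstance(result, list) else [result]
    return [item.model_dump() if hasattr(item, "model_dump") else item for item in items]
```

This keeps call_tool's contract loose while giving ingestion code a single, uniform list-of-dicts shape.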

i'll push an update with (2) shortly. thank you!
