Skip to content

Commit

Permalink
refactor: Cleanup core udf param types + allow empty child workflow t…
Browse files Browse the repository at this point in the history
…rigger inputs (#813)
  • Loading branch information
daryllimyt authored Jan 30, 2025
1 parent c4dfad4 commit 628185d
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 75 deletions.
43 changes: 29 additions & 14 deletions registry/tracecat_registry/base/core/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from typing import Annotated, Any, Literal

import nh3
from pydantic import Field
from tracecat.config import TRACECAT__ALLOWED_EMAIL_ATTRIBUTES
from typing_extensions import Doc

from tracecat_registry import RegistrySecret, registry, secrets

Expand Down Expand Up @@ -100,36 +100,51 @@ def _build_email_message(
secrets=[smtp_secret],
)
def send_email_smtp(
sender: Annotated[str, Field(..., description="Email address of the sender")],
sender: Annotated[
str,
Doc("Email address of the sender"),
],
recipients: Annotated[
list[str], Field(..., description="List of recipient email addresses")
list[str],
Doc("List of recipient email addresses"),
],
subject: Annotated[
str,
Doc("Subject of the email"),
],
body: Annotated[
str,
Doc("Body content of the email"),
],
subject: Annotated[str, Field(..., description="Subject of the email")],
body: Annotated[str, Field(..., description="Body content of the email")],
content_type: Annotated[
Literal["text/plain", "text/html"],
Field(
None,
description="Email content type ('text/plain' or 'text/html'). Defaults to 'text/plain'.",
Doc(
"Email content type ('text/plain' or 'text/html'). Defaults to 'text/plain'."
),
] = "text/plain",
timeout: Annotated[
float | None, Field(None, description="Timeout for SMTP operations in seconds")
float | None,
Doc("Timeout for SMTP operations in seconds"),
] = None,
headers: Annotated[
dict[str, str] | None, Field(None, description="Additional email headers")
dict[str, str] | None,
Doc("Additional email headers"),
] = None,
enable_starttls: Annotated[
bool, Field(False, description="Enable STARTTLS for secure connection")
bool,
Doc("Enable STARTTLS for secure connection"),
] = False,
enable_ssl: Annotated[
bool, Field(False, description="Enable SSL for secure connection")
bool,
Doc("Enable SSL for secure connection"),
] = False,
enable_auth: Annotated[
bool, Field(False, description="Enable SMTP authentication")
bool,
Doc("Enable SMTP authentication"),
] = False,
ignore_cert_errors: Annotated[
bool, Field(False, description="Ignore SSL certificate errors")
bool,
Doc("Ignore SSL certificate errors"),
] = False,
) -> dict[str, Any]:
"""Run a send email action.
Expand Down
5 changes: 3 additions & 2 deletions registry/tracecat_registry/base/core/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def add(
include_in_schema=False,
)
def my_function(
age: Annotated[int, Field(30, description="Persons age in years")],
name: Annotated[str, Field(description="Name of the person")] = None,
age: Annotated[int, Doc("Persons age in years")] = 30,
name: Annotated[str | None, Doc("Name of the person")] = None,
is_member: bool = False,
) -> Member:
"""My function
Expand All @@ -66,6 +66,7 @@ def my_function(
Stats
the result
"""
name = name or "John Doe"
return Member(name=name, age=age, is_member=is_member)


Expand Down
29 changes: 18 additions & 11 deletions registry/tracecat_registry/base/core/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# XXX(WARNING): Do not import __future__ annotations from typing
# This will cause class types to be resolved as strings

from typing import Annotated, Any, Literal
from typing import Annotated, Any, Literal, cast

from pydantic import Field
from tracecat.llm import DEFAULT_MODEL_TYPE, route_llm_call
from tracecat.llm import DEFAULT_MODEL_TYPE, ModelType, route_llm_call
from typing_extensions import Doc

from tracecat_registry import RegistrySecret, registry

Expand All @@ -30,12 +30,17 @@
secrets=[llm_secret],
)
async def ai_action(
prompt: Annotated[str, Field(description="The prompt to send to the AI")],
prompt: Annotated[
str,
Doc("The prompt to send to the AI"),
],
system_context: Annotated[
str, Field(description="The system context")
str,
Doc("The system context"),
] = DEFAULT_SYSTEM_CONTEXT,
execution_context: Annotated[
dict[str, Any] | None, Field(description="The current execution context")
dict[str, Any] | None,
Doc("The current execution context"),
] = None,
model: Annotated[
Literal[
Expand All @@ -50,12 +55,14 @@ async def ai_action(
"gpt-4-vision-preview",
"gpt-3.5-turbo-0125",
],
Field(
description="The AI Model to use. If you use an OpenAI model (gpt family), you must have the `OPENAI_API_KEY` secret set.",
Doc(
"The AI Model to use. If you use an OpenAI model (gpt family),"
" you must have the `OPENAI_API_KEY` secret set."
),
] = DEFAULT_MODEL_TYPE,
] = DEFAULT_MODEL_TYPE.value,
additional_config: Annotated[
dict[str, Any] | None, Field(description="Additional configuration")
dict[str, Any] | None,
Doc("Additional configuration"),
] = None,
):
exec_ctx_str = (
Expand All @@ -79,6 +86,6 @@ async def ai_action(
return await route_llm_call(
prompt=prompt,
system_context=system_context,
model=model,
model=cast(ModelType, model),
additional_config=additional_config,
)
16 changes: 10 additions & 6 deletions registry/tracecat_registry/base/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from typing import Annotated, Any

from pydantic import Field
from tracecat.expressions import functions
from typing_extensions import Doc

from tracecat_registry import registry

Expand All @@ -17,7 +17,10 @@
display_group="Data Transform",
)
def reshape(
value: Annotated[Any, Field(..., description="The value to reshape")],
value: Annotated[
Any,
Doc("The value to reshape"),
],
) -> Any:
return value

Expand All @@ -29,12 +32,13 @@ def reshape(
display_group="Data Transform",
)
def filter(
items: Annotated[list[Any], Field(..., description="A collection of items.")],
items: Annotated[
list[Any],
Doc("A collection of items."),
],
python_lambda: Annotated[
str,
Field(
..., description="A Python lambda function for filtering the collection."
),
Doc("A Python lambda function for filtering the collection."),
],
) -> Any:
return functions.filter_(items=items, python_lambda=python_lambda)
60 changes: 19 additions & 41 deletions registry/tracecat_registry/base/core/workflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated, Any, Literal

from pydantic import Field
from tracecat.identifiers.workflow import AnyWorkflowID
from typing_extensions import Doc

from tracecat_registry import RegistryActionError, registry

Expand All @@ -16,72 +16,50 @@ async def execute(
*,
workflow_id: Annotated[
AnyWorkflowID | None,
Field(
default=None,
description=(
"The ID of the child workflow to execute. Must be provided if workflow_alias is not provided."
),
Doc(
"The ID of the child workflow to execute. Must be provided if workflow_alias is not provided.",
),
] = None,
workflow_alias: Annotated[
str | None,
Field(
default=None,
description=(
"The alias of the child workflow to execute. Must be provided if workflow_id is not provided."
),
Doc(
"The alias of the child workflow to execute. Must be provided if workflow_id is not provided.",
),
] = None,
trigger_inputs: Annotated[
dict[str, Any],
Field(
...,
description="The inputs to pass to the child workflow.",
),
],
dict[str, Any] | None,
Doc("The inputs to pass to the child workflow."),
] = None,
environment: Annotated[
str | None,
Field(
description=(
"The child workflow's target execution environment. "
"This is used to isolate secrets across different environments."
"If not provided, the child workflow's default environment is used. "
),
Doc(
"The child workflow's target execution environment. "
"This is used to isolate secrets across different environments."
"If not provided, the child workflow's default environment is used. "
),
] = None,
timeout: Annotated[
float | None,
Field(
description=(
"The maximum number of seconds to wait for the child workflow to complete. "
"If not provided, the child workflow's default timeout is used. "
),
Doc(
"The maximum number of seconds to wait for the child workflow to complete. "
"If not provided, the child workflow's default timeout is used. "
),
] = None,
version: Annotated[
int | None,
Field(..., description="The version of the child workflow definition, if any."),
Doc("The version of the child workflow definition, if any."),
] = None,
loop_strategy: Annotated[
Literal["parallel", "batch", "sequential"],
Field(
...,
description="The execution strategy to use for the child workflow.",
),
Doc("The execution strategy to use for the child workflow."),
] = "parallel",
batch_size: Annotated[
int,
Field(
...,
description="The number of child workflows to execute in parallel.",
),
Doc("The number of child workflows to execute in parallel."),
] = 16,
fail_strategy: Annotated[
Literal["isolated", "all"],
Field(
...,
description="Fail strategy to use when a child workflow fails.",
),
Doc("Fail strategy to use when a child workflow fails."),
] = "isolated",
) -> Any:
raise RegistryActionError(
Expand Down
2 changes: 1 addition & 1 deletion tracecat/dsl/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def validate_workflow_id(cls, v: AnyWorkflowID) -> WorkflowUUID:
class ExecuteChildWorkflowArgs(BaseModel):
workflow_id: WorkflowUUID | None = None
workflow_alias: str | None = None
trigger_inputs: TriggerInputs
trigger_inputs: TriggerInputs | None = None
environment: str | None = None
version: int | None = None
loop_strategy: LoopStrategy = LoopStrategy.BATCH
Expand Down

0 comments on commit 628185d

Please sign in to comment.