Skip to content

Commit

Permalink
Merge pull request #153 from aitomatic/refactor
Browse files Browse the repository at this point in the history
refactor openssa.l2 submodules
  • Loading branch information
TheVinhLuong102 authored Apr 23, 2024
2 parents 578ff94 + 6ec8f6c commit ba6f604
Show file tree
Hide file tree
Showing 11 changed files with 34 additions and 126 deletions.
1 change: 1 addition & 0 deletions examples/FinanceBench/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.data/
.streamlit/secrets.toml
*.ipynb
*.log
13 changes: 2 additions & 11 deletions openssa/l2/agent/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,11 @@
from openssa.l2.resource.abstract import AResource


@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False,
weakref_slot=False)
@dataclass
class AbstractAgent(ABC):
"""Abstract agent with planning, reasoning & informational resources."""

planner: APlanner
planner: APlanner | None = None
reasoner: AReasoner = field(default_factory=BaseReasoner)
resources: set[AResource] = field(default_factory=set,
init=True,
Expand Down
11 changes: 1 addition & 10 deletions openssa/l2/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,7 @@
from .abstract import AbstractAgent


@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False,
weakref_slot=False)
@dataclass
class Agent(AbstractAgent):
"""Agent with planning, reasoning & informational resources."""

Expand Down
22 changes: 2 additions & 20 deletions openssa/l2/planning/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,7 @@
from openssa.l2.task.abstract import ATask


@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False,
weakref_slot=False)
@dataclass
class AbstractPlan(ABC):
"""Abstract plan."""
task: ATask
Expand All @@ -38,16 +29,7 @@ def execute(self, reasoner: AReasoner = BaseReasoner()) -> str:
APlan: TypeVar = TypeVar('APlan', bound=AbstractPlan, covariant=False, contravariant=False)


@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False,
weakref_slot=False)
@dataclass
class AbstractPlanner(ABC):
"""Abstract planner."""

Expand Down
26 changes: 4 additions & 22 deletions openssa/l2/planning/hierarchical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,7 @@ class HTPDict(TypedDict, total=False):
type AskAnsPair = tuple[str, str]


@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False,
weakref_slot=False)
@dataclass
class HTP(AbstractPlan):
"""Hierarchical task plan (HTP)."""

Expand All @@ -67,8 +58,8 @@ def to_dict(self) -> HTPDict:
def fix_missing_resources(self):
"""Fix missing resources in HTP."""
for p in self.sub_plans:
if not p.task.resource:
p.task.resource: AResource | None = self.task.resource
if not p.task.resources:
p.task.resources: set[AResource] | None = self.task.resources
p.fix_missing_resources()

def execute(self, reasoner: AReasoner = BaseReasoner(), other_results: list[AskAnsPair] | None = None) -> str:
Expand Down Expand Up @@ -107,16 +98,7 @@ def execute(self, reasoner: AReasoner = BaseReasoner(), other_results: list[AskA
return self.task.result


@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False,
weakref_slot=False)
@dataclass
class AutoHTPlanner(AbstractPlanner):
"""Automated (generative) hierarchical task planner."""

Expand Down
11 changes: 1 addition & 10 deletions openssa/l2/reasoning/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,7 @@
from openssa.utils.llms import AnLLM, OpenAILLM


@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False,
weakref_slot=False)
@dataclass
class AbstractReasoner(ABC):
"""Abstract reasoner."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,11 @@

from dataclasses import dataclass

from openssa.l2.reasoning.abstract import AbstractReasoner
from openssa.l2.task.abstract import ATask

from .abstract import AbstractReasoner


@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False,
weakref_slot=False)
@dataclass
class BaseReasoner(AbstractReasoner):
"""Base reasoner."""

Expand Down
11 changes: 1 addition & 10 deletions openssa/l2/reasoning/ooda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,7 @@
from openssa.l2.reasoning.base import BaseReasoner


@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False,
weakref_slot=False)
@dataclass
class OodaReasoner(BaseReasoner):
"""OODA reasoner."""

Expand Down
10 changes: 1 addition & 9 deletions openssa/l2/resource/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,7 @@


@global_register
@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False)
@dataclass
class FileResource(AbstractResource):
"""File-stored informational resource."""

Expand Down
36 changes: 17 additions & 19 deletions openssa/l2/task/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,58 @@


from abc import ABC
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Self, TypedDict, Required, NotRequired, TypeVar

from openssa.l2.resource.abstract import AbstractResource
from openssa.l2.resource.abstract import AResource
from openssa.l2.resource._global import GLOBAL_RESOURCES

from .status import TaskStatus


class TaskDict(TypedDict, total=False):
ask: Required[str]
resource: NotRequired[AbstractResource]
resources: NotRequired[set[AResource]]
status: NotRequired[TaskStatus]
result: NotRequired[str]


@dataclass(init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False, # mutable
match_args=True,
kw_only=False,
slots=False,
weakref_slot=False)
@dataclass
class AbstractTask(ABC):
"""Abstract task."""

ask: str
resource: AbstractResource | None = None
resources: set[AResource] = field(default_factory=set,
init=True,
repr=True,
hash=False, # mutable
compare=True,
metadata=None,
kw_only=True)
status: TaskStatus = TaskStatus.PENDING
result: str | None = None

@classmethod
def from_dict(cls, d: TaskDict, /) -> Self:
"""Create resource instance from dictionary representation."""
"""Create task from dictionary representation."""
task: Self = cls(**d)

if isinstance(task.resource, str):
task.resource: AbstractResource = GLOBAL_RESOURCES[task.resource]
if task.resources:
task.resources: set[AResource] = {(GLOBAL_RESOURCES[r] if isinstance(r, str) else r)
for r in task.resources}

task.status: TaskStatus = TaskStatus(task.status)

return task

@classmethod
def from_str(cls, s: str, /) -> Self:
"""Create resource instance from dictionary representation."""
"""Create task from string representation."""
return cls(ask=s)

@classmethod
def from_dict_or_str(cls, dict_or_str: TaskDict | str, /) -> Self:
"""Create resource instance from dictionary or string representation."""
"""Create task from dictionary or string representation."""
if isinstance(dict_or_str, dict):
return cls.from_dict(dict_or_str)

Expand Down
5 changes: 2 additions & 3 deletions openssa/l2/task/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

class TaskStatus(StrEnum):
PENDING: str = auto()
IN_PROGRESS: str = auto()
NEEDING_DECOMPOSITION: str = auto()
DECOMPOSED: str = auto()
DONE: str = auto()
FAILED: str = auto()
TIMED_OUT: str = auto()

0 comments on commit ba6f604

Please sign in to comment.