diff --git a/tierkreis/tierkreis/cli/run_workflow.py b/tierkreis/tierkreis/cli/run_workflow.py index 2446c73db..6d9f05fa3 100644 --- a/tierkreis/tierkreis/cli/run_workflow.py +++ b/tierkreis/tierkreis/cli/run_workflow.py @@ -2,6 +2,7 @@ import uuid import logging +from tierkreis.builder import GraphBuilder from tierkreis.controller import run_graph from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.location import Loc @@ -14,7 +15,7 @@ def run_workflow( - graph: GraphData, + graph: GraphData | GraphBuilder, inputs: dict[str, PType], name: str | None = None, run_id: int | None = None, @@ -56,6 +57,14 @@ def run_workflow( polling_interval_seconds, ) if print_output: - all_outputs = graph.nodes[graph.output_idx()].inputs + if isinstance(graph, GraphData): + nodes = graph.nodes + output_idx = graph.output_idx() + else: + graph_data = graph.get_data() + nodes = graph_data.nodes + output_idx = graph_data.output_idx() + + all_outputs = nodes[output_idx].inputs for output in all_outputs: print(f"{output}: {storage.read_output(Loc(), output)}") diff --git a/tierkreis/tierkreis/cli/tkr.py b/tierkreis/tierkreis/cli/tkr.py index 7edc5eb7b..1dc5c4c72 100644 --- a/tierkreis/tierkreis/cli/tkr.py +++ b/tierkreis/tierkreis/cli/tkr.py @@ -9,6 +9,7 @@ from typing import Any, Callable from tierkreis.cli.run_workflow import run_workflow +from tierkreis.builder import GraphBuilder from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.types import PType, ptype_from_bytes from tierkreis.exceptions import TierkreisError @@ -22,7 +23,7 @@ def _import_from_path(module_name: str, file_path: str) -> Any: return module -def load_graph(graph_input: str) -> GraphData: +def load_graph(graph_input: str) -> GraphData | GraphBuilder: if ":" not in graph_input: raise TierkreisError(f"Invalid argument: {graph_input}") module_name, function_name = graph_input.split(":") @@ -31,7 +32,9 @@ def load_graph(graph_input: str) -> GraphData: module = _import_from_path("graph_module", module_name) else: module = importlib.import_module(module_name, __package__) - build_submission_graph: Callable[[], GraphData] = getattr(module, function_name) + build_submission_graph: Callable[[], GraphData | GraphBuilder] = getattr( + module, function_name + ) return build_submission_graph() @@ -64,7 +67,7 @@ def parse_args( graph.add_argument( "-g", "--graph-location", - help="Fully qualifying name of a Callable () -> GraphData. " + help="Fully qualifying name of a Callable () -> GraphBuilder. " + "Example: tierkreis.cli.sample_graph:simple_eval" + "Or a path to a python file and function." + "Example: examples/hello_world/hello_world_graph.py:hello_graph", diff --git a/tierkreis/tierkreis/controller/__init__.py b/tierkreis/tierkreis/controller/__init__.py index 5db6d7e2a..4cce5c3d6 100644 --- a/tierkreis/tierkreis/controller/__init__.py +++ b/tierkreis/tierkreis/controller/__init__.py @@ -1,6 +1,7 @@ import logging from time import sleep +from tierkreis.builder import GraphBuilder from tierkreis.controller.data.graph import Eval, GraphData from tierkreis.controller.data.location import Loc from tierkreis.controller.data.types import PType, bytes_from_ptype, ptype_from_bytes @@ -18,11 +19,16 @@ def run_graph( storage: ControllerStorage, executor: ControllerExecutor, - g: GraphData, + graph: GraphData | GraphBuilder, graph_inputs: dict[str, PType] | PType, n_iterations: int = 10000, polling_interval_seconds: float = 0.01, ) -> None: + if isinstance(graph, GraphBuilder): + g = graph.get_data() + else: + g = graph + if not isinstance(graph_inputs, dict): graph_inputs = {"value": graph_inputs} remaining_inputs = g.remaining_inputs({k for k in graph_inputs.keys()}) diff --git a/tierkreis_visualization/tierkreis_visualization/cli.py b/tierkreis_visualization/tierkreis_visualization/cli.py index 5f6ddf3f3..ede51ea11 100644 --- a/tierkreis_visualization/tierkreis_visualization/cli.py +++ b/tierkreis_visualization/tierkreis_visualization/cli.py @@ -1,8 +1,21 @@ from __future__ import annotations import argparse +import os +from pathlib import Path +from tierkreis.builder import GraphBuilder +from watchfiles import PythonFilter, run_process -from tierkreis_visualization.main import start +from tierkreis.cli.tkr import load_graph +from tierkreis_visualization.main import start, visualize_graph + + +def _vizualize_from_location(graph_location: str) -> None: + graph = load_graph(graph_location) + if isinstance(graph, GraphBuilder): + visualize_graph(graph.get_data()) + else: + visualize_graph(graph) class TierkreisVizCli: @@ -16,6 +29,41 @@ def add_subcommand( ) parser.set_defaults(func=TierkreisVizCli.execute) + dev_parser = main_parser.add_parser( + "dev", description="Inspect a graph without running it" + ) + dev_parser.add_argument( + "-g", + "--graph-location", + help="Fully qualifying name of a Callable () -> GraphBuilder. " + + "Example: tierkreis.cli.sample_graph:simple_eval" + + "Or a path to a python file and function." + + "Example: examples/hello_world/hello_world_graph.py:hello_graph", + type=str, + required=True, + ) + dev_parser.set_defaults(func=TierkreisVizCli.dev) + @staticmethod def execute(args: argparse.Namespace) -> None: start() + + @staticmethod + def dev(args: argparse.Namespace) -> None: + module_name, _function_name = args.graph_location.split(":") + # By default watch the directory with the graph definition in for changes. + # + # This might be suboptimal when changes happen in e.g. the virtual env + # but this is reasonably similar to how fastapi dev server works. + if ".py" not in module_name: + path = Path(os.getcwd()).parent + else: + path = Path(module_name).parent + print(f"Listening for changes at {path}") + run_process( + path, + target=_vizualize_from_location, + args=(args.graph_location,), + watch_filter=PythonFilter(), + recursive=True, + )