Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions tierkreis/tierkreis/cli/run_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import uuid
import logging

from tierkreis.builder import GraphBuilder
from tierkreis.controller import run_graph
from tierkreis.controller.data.graph import GraphData
from tierkreis.controller.data.location import Loc
Expand All @@ -14,7 +15,7 @@


def run_workflow(
graph: GraphData,
graph: GraphData | GraphBuilder,
inputs: dict[str, PType],
name: str | None = None,
run_id: int | None = None,
Expand Down Expand Up @@ -56,6 +57,14 @@ def run_workflow(
polling_interval_seconds,
)
if print_output:
all_outputs = graph.nodes[graph.output_idx()].inputs
if isinstance(graph, GraphData):
nodes = graph.nodes
output_idx = graph.output_idx()
else:
graph_data = graph.get_data()
nodes = graph_data.nodes
output_idx = graph_data.output_idx()

all_outputs = nodes[output_idx].inputs
for output in all_outputs:
print(f"{output}: {storage.read_output(Loc(), output)}")
9 changes: 6 additions & 3 deletions tierkreis/tierkreis/cli/tkr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable

from tierkreis.cli.run_workflow import run_workflow
from tierkreis.builder import GraphBuilder
from tierkreis.controller.data.graph import GraphData
from tierkreis.controller.data.types import PType, ptype_from_bytes
from tierkreis.exceptions import TierkreisError
Expand All @@ -22,7 +23,7 @@ def _import_from_path(module_name: str, file_path: str) -> Any:
return module


def load_graph(graph_input: str) -> GraphData:
def load_graph(graph_input: str) -> GraphData | GraphBuilder:
if ":" not in graph_input:
raise TierkreisError(f"Invalid argument: {graph_input}")
module_name, function_name = graph_input.split(":")
Expand All @@ -31,7 +32,9 @@ def load_graph(graph_input: str) -> GraphData:
module = _import_from_path("graph_module", module_name)
else:
module = importlib.import_module(module_name, __package__)
build_submission_graph: Callable[[], GraphData] = getattr(module, function_name)
build_submission_graph: Callable[[], GraphData | GraphBuilder] = getattr(
module, function_name
)

return build_submission_graph()

Expand Down Expand Up @@ -64,7 +67,7 @@ def parse_args(
graph.add_argument(
"-g",
"--graph-location",
help="Fully qualifying name of a Callable () -> GraphData. "
help="Fully qualifying name of a Callable () -> GraphBuilder. "
+ "Example: tierkreis.cli.sample_graph:simple_eval"
+ "Or a path to a python file and function."
+ "Example: examples/hello_world/hello_world_graph.py:hello_graph",
Expand Down
8 changes: 7 additions & 1 deletion tierkreis/tierkreis/controller/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from time import sleep

from tierkreis.builder import GraphBuilder
from tierkreis.controller.data.graph import Eval, GraphData
from tierkreis.controller.data.location import Loc
from tierkreis.controller.data.types import PType, bytes_from_ptype, ptype_from_bytes
Expand All @@ -18,11 +19,16 @@
def run_graph(
storage: ControllerStorage,
executor: ControllerExecutor,
g: GraphData,
graph: GraphData | GraphBuilder,
graph_inputs: dict[str, PType] | PType,
n_iterations: int = 10000,
polling_interval_seconds: float = 0.01,
) -> None:
if isinstance(graph, GraphBuilder):
g = graph.get_data()
else:
g = graph

if not isinstance(graph_inputs, dict):
graph_inputs = {"value": graph_inputs}
remaining_inputs = g.remaining_inputs({k for k in graph_inputs.keys()})
Expand Down
50 changes: 49 additions & 1 deletion tierkreis_visualization/tierkreis_visualization/cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
from __future__ import annotations
import argparse
import os
from pathlib import Path

from tierkreis.builder import GraphBuilder
from watchfiles import PythonFilter, run_process

from tierkreis_visualization.main import start
from tierkreis.cli.tkr import load_graph
from tierkreis_visualization.main import start, visualize_graph


def _vizualize_from_location(graph_location: str) -> None:
graph = load_graph(graph_location)
if isinstance(graph, GraphBuilder):
visualize_graph(graph.get_data())
else:
visualize_graph(graph)


class TierkreisVizCli:
Expand All @@ -16,6 +29,41 @@ def add_subcommand(
)
parser.set_defaults(func=TierkreisVizCli.execute)

dev_parser = main_parser.add_parser(
"dev", description="Inspect a graph without running it"
)
dev_parser.add_argument(
"-g",
"--graph-location",
help="Fully qualifying name of a Callable () -> GraphBuilder. "
+ "Example: tierkreis.cli.sample_graph:simple_eval"
+ "Or a path to a python file and function."
+ "Example: examples/hello_world/hello_world_graph.py:hello_graph",
type=str,
required=True,
)
dev_parser.set_defaults(func=TierkreisVizCli.dev)

@staticmethod
def execute(args: argparse.Namespace) -> None:
start()

@staticmethod
def dev(args: argparse.Namespace) -> None:
module_name, _function_name = args.graph_location.split(":")
# By default watch the directory with the graph definition in for changes.
#
# This might be suboptimal when changes happen in e.g. the virtual env
# but this is reasonably similar to how fastapi dev server works.
if ".py" not in module_name:
path = Path(os.getcwd()).parent
else:
path = Path(module_name).parent
print(f"Listening for changes at {path}")
run_process(
path,
target=_vizualize_from_location,
args=(args.graph_location,),
watch_filter=PythonFilter(),
recursive=True,
)