Skip to content

Commit ccb25e3

Browse files
committed
Add support for fetching job assets/params from workflows
1 parent f85958c commit ccb25e3

File tree

2 files changed

+109
-4
lines changed

2 files changed

+109
-4
lines changed

metafold/workflows.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
from attrs import field, frozen
22
from datetime import datetime
33
from metafold.api import asdatetime, asdict, optional_datetime
4+
from metafold.assets import Asset
45
from metafold.client import Client
56
from metafold.exceptions import PollTimeout
7+
from metafold.jobs import Job
68
from requests import Response
79
from typing import Optional, Union
10+
import typing
11+
12+
if typing.TYPE_CHECKING:
13+
from metafold import MetafoldClient
814

915

1016
@frozen(kw_only=True)
@@ -21,6 +27,9 @@ class Workflow:
2127
definition: Workflow definition string.
2228
project_id: Project ID.
2329
"""
30+
_client: "MetafoldClient"
31+
_jobs: dict[str, str] = field(factory=dict, init=False)
32+
2433
id: str
2534
jobs: list[str] = field(factory=list)
2635
state: str
@@ -32,6 +41,52 @@ class Workflow:
3241
definition: str
3342
project_id: str
3443

44+
def get_asset(self, path: str) -> Asset | None:
45+
"""Retrieve an asset from the workflow by dot notation.
46+
47+
Args:
48+
path: Path to asset in the form "job.name", e.g. "sample-mesh.volume"
49+
searches for the asset "volume" from the "sample-mesh" job.
50+
"""
51+
job_name, asset_name = self._parse_path(path)
52+
job = self._find_job(job_name)
53+
if not job or not job.outputs.assets:
54+
return
55+
for name, asset in job.outputs.assets.items():
56+
if name == asset_name:
57+
return asset
58+
59+
def get_parameter(self, path: str) -> str | None:
60+
"""Retrieve a parameter from the workflow by dot notation.
61+
62+
Args:
63+
path: Path to parameter in the form "job.name", e.g. "sample-mesh.patch_size"
64+
searches for the parameter "patch_size" from the "sample-mesh" job.
65+
"""
66+
job_name, param_name = self._parse_path(path)
67+
job = self._find_job(job_name)
68+
if not job or not job.outputs.params:
69+
return
70+
for name, param in job.outputs.params.items():
71+
if name == param_name:
72+
return param
73+
74+
def _find_job(self, name: str) -> Job | None:
75+
# FIXME(ryan): Update API to return job names as well as IDs.
76+
# For now we cache a mapping b/w job name and job id.
77+
if job_id := self._jobs.get(name):
78+
return self._client.jobs.get(job_id)
79+
80+
for job_id in self.jobs:
81+
job = self._client.jobs.get(job_id)
82+
if job.name == name:
83+
self._jobs[name] = job_id
84+
return job
85+
86+
@staticmethod
87+
def _parse_path(path: str) -> tuple[str, str]:
88+
return path.split(".", maxsplit=1)
89+
3590

3691
class WorkflowsEndpoint:
3792
"""Metafold workflows endpoint."""
@@ -61,7 +116,7 @@ def list(
61116
url = f"/projects/{project_id}/workflows"
62117
payload = asdict(sort=sort, q=q)
63118
r: Response = self._client.get(url, params=payload)
64-
return [Workflow(**w) for w in r.json()]
119+
return [Workflow(client=self._client, **w) for w in r.json()]
65120

66121
def get(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow:
67122
"""Get a workflow.
@@ -76,7 +131,7 @@ def get(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow:
76131
project_id = self._client.project_id(project_id)
77132
url = f"/projects/{project_id}/workflows/{workflow_id}"
78133
r: Response = self._client.get(url)
79-
return Workflow(**r.json())
134+
return Workflow(client=self._client, **r.json())
80135

81136
def run(
82137
self, definition: str,
@@ -110,7 +165,7 @@ def run(
110165
raise RuntimeError(
111166
f"Workflow failed to complete within {timeout} seconds"
112167
) from e
113-
return Workflow(**r.json())
168+
return Workflow(client=self._client, **r.json())
114169

115170
def cancel(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow:
116171
"""Cancel a running workflow.
@@ -125,7 +180,7 @@ def cancel(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow
125180
project_id = self._client.project_id(project_id)
126181
url = f"/projects/{project_id}/workflows/{workflow_id}/cancel"
127182
r: Response = self._client.post(url)
128-
return Workflow(**r.json())
183+
return Workflow(client=self._client, **r.json())
129184

130185
def delete(self, workflow_id: str, project_id: Optional[str] = None):
131186
"""Delete a workflow.

tests/test_workflows.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,35 @@
5353
"project_id": "1",
5454
}
5555

56+
mock_job = {
57+
"type": "test_job",
58+
"state": "success",
59+
"created": "Mon, 01 Jan 2024 00:00:00 GMT",
60+
"started": "Mon, 01 Jan 2024 00:00:00 GMT",
61+
"finished": "Mon, 01 Jan 2024 00:00:00 GMT",
62+
"error": None,
63+
"inputs": {
64+
"params": None,
65+
},
66+
"outputs": {
67+
"params": {
68+
"foo": "1",
69+
"bar": "a",
70+
"baz": "[2, \"b\"]",
71+
},
72+
},
73+
"needs": [],
74+
"project_id": "1",
75+
"workflow_id": None,
76+
"assets": [],
77+
"parameters": {
78+
"foo": "1",
79+
"bar": "a",
80+
"baz": "[2, \"b\"]",
81+
},
82+
"meta": None,
83+
}
84+
5685
poll_count: int = 0
5786

5887

@@ -90,6 +119,20 @@ def do_GET(self):
90119
self.send_header("Content-Type", "application/json")
91120
self.end_headers()
92121
self.wfile.write(json.dumps(payload).encode())
122+
elif u.path == "/projects/1/jobs/1":
123+
self.send_response(HTTPStatus.OK)
124+
self.send_header("Content-Type", "application/json")
125+
self.end_headers()
126+
payload = deepcopy(mock_job)
127+
payload.update({"id": "1", "name": "test-job-1"})
128+
self.wfile.write(json.dumps(payload).encode())
129+
elif u.path == "/projects/1/jobs/2":
130+
self.send_response(HTTPStatus.OK)
131+
self.send_header("Content-Type", "application/json")
132+
self.end_headers()
133+
payload = deepcopy(mock_job)
134+
payload.update({"id": "2", "name": "test-job-2"})
135+
self.wfile.write(json.dumps(payload).encode())
93136
else:
94137
self.send_error(HTTPStatus.NOT_FOUND)
95138

@@ -131,6 +174,7 @@ def test_list_workflows_filtered(client):
131174
def test_get_workflow(client):
132175
w = client.workflows.get("1")
133176
assert w == Workflow(
177+
client=client,
134178
id="1",
135179
jobs=["1", "2"],
136180
state="success",
@@ -141,11 +185,17 @@ def test_get_workflow(client):
141185
project_id="1",
142186
)
143187

188+
# Find params
189+
assert w.get_parameter("test-job-2.foo") == "1"
190+
assert w.get_parameter("test-job-2.foo") == "1" # Should use cached id
191+
assert w.get_parameter("test-job-2.bar") == "a"
192+
144193

145194
def test_run_workflow(client):
146195
definition = "foo"
147196
w = client.workflows.run(definition)
148197
assert w == Workflow(
198+
client=client,
149199
id="1",
150200
jobs=["1", "2"],
151201
state="success",

0 commit comments

Comments
 (0)