1
1
from attrs import field , frozen
2
2
from datetime import datetime
3
3
from metafold .api import asdatetime , asdict , optional_datetime
4
+ from metafold .assets import Asset
4
5
from metafold .client import Client
5
6
from metafold .exceptions import PollTimeout
7
+ from metafold .jobs import Job
6
8
from requests import Response
7
9
from typing import Optional , Union
10
+ import typing
11
+
12
+ if typing .TYPE_CHECKING :
13
+ from metafold import MetafoldClient
8
14
9
15
10
16
@frozen (kw_only = True )
@@ -21,6 +27,9 @@ class Workflow:
21
27
definition: Workflow definition string.
22
28
project_id: Project ID.
23
29
"""
30
+ _client : "MetafoldClient"
31
+ _jobs : dict [str , str ] = field (factory = dict , init = False )
32
+
24
33
id : str
25
34
jobs : list [str ] = field (factory = list )
26
35
state : str
@@ -32,6 +41,52 @@ class Workflow:
32
41
definition : str
33
42
project_id : str
34
43
44
+ def get_asset (self , path : str ) -> Asset | None :
45
+ """Retrieve an asset from the workflow by dot notation.
46
+
47
+ Args:
48
+ path: Path to asset in the form "job.name", e.g. "sample-mesh.volume"
49
+ searches for the asset "volume" from the "sample-mesh" job.
50
+ """
51
+ job_name , asset_name = self ._parse_path (path )
52
+ job = self ._find_job (job_name )
53
+ if not job or not job .outputs .assets :
54
+ return
55
+ for name , asset in job .outputs .assets .items ():
56
+ if name == asset_name :
57
+ return asset
58
+
59
+ def get_parameter (self , path : str ) -> str | None :
60
+ """Retrieve a parameter from the workflow by dot notation.
61
+
62
+ Args:
63
+ path: Path to parameter in the form "job.name", e.g. "sample-mesh.patch_size"
64
+ searches for the parameter "patch_size" from the "sample-mesh" job.
65
+ """
66
+ job_name , param_name = self ._parse_path (path )
67
+ job = self ._find_job (job_name )
68
+ if not job or not job .outputs .params :
69
+ return
70
+ for name , param in job .outputs .params .items ():
71
+ if name == param_name :
72
+ return param
73
+
74
+ def _find_job (self , name : str ) -> Job | None :
75
+ # FIXME(ryan): Update API to return job names as well as IDs.
76
+ # For now we cache a mapping b/w job name and job id.
77
+ if job_id := self ._jobs .get (name ):
78
+ return self ._client .jobs .get (job_id )
79
+
80
+ for job_id in self .jobs :
81
+ job = self ._client .jobs .get (job_id )
82
+ if job .name == name :
83
+ self ._jobs [name ] = job_id
84
+ return job
85
+
86
+ @staticmethod
87
+ def _parse_path (path : str ) -> tuple [str , str ]:
88
+ return path .split ("." , maxsplit = 1 )
89
+
35
90
36
91
class WorkflowsEndpoint :
37
92
"""Metafold workflows endpoint."""
@@ -61,7 +116,7 @@ def list(
61
116
url = f"/projects/{ project_id } /workflows"
62
117
payload = asdict (sort = sort , q = q )
63
118
r : Response = self ._client .get (url , params = payload )
64
- return [Workflow (** w ) for w in r .json ()]
119
+ return [Workflow (client = self . _client , ** w ) for w in r .json ()]
65
120
66
121
def get (self , workflow_id : str , project_id : Optional [str ] = None ) -> Workflow :
67
122
"""Get a workflow.
@@ -76,7 +131,7 @@ def get(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow:
76
131
project_id = self ._client .project_id (project_id )
77
132
url = f"/projects/{ project_id } /workflows/{ workflow_id } "
78
133
r : Response = self ._client .get (url )
79
- return Workflow (** r .json ())
134
+ return Workflow (client = self . _client , ** r .json ())
80
135
81
136
def run (
82
137
self , definition : str ,
@@ -110,7 +165,7 @@ def run(
110
165
raise RuntimeError (
111
166
f"Workflow failed to complete within { timeout } seconds"
112
167
) from e
113
- return Workflow (** r .json ())
168
+ return Workflow (client = self . _client , ** r .json ())
114
169
115
170
def cancel (self , workflow_id : str , project_id : Optional [str ] = None ) -> Workflow :
116
171
"""Cancel a running workflow.
@@ -125,7 +180,7 @@ def cancel(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow
125
180
project_id = self ._client .project_id (project_id )
126
181
url = f"/projects/{ project_id } /workflows/{ workflow_id } /cancel"
127
182
r : Response = self ._client .post (url )
128
- return Workflow (** r .json ())
183
+ return Workflow (client = self . _client , ** r .json ())
129
184
130
185
def delete (self , workflow_id : str , project_id : Optional [str ] = None ):
131
186
"""Delete a workflow.
0 commit comments