Skip to content

Commit aea466b

Browse files
committed
Add volume_size parameter to ComputeCurvatures
1 parent f93852f commit aea466b

File tree

7 files changed

+45
-18
lines changed

7 files changed

+45
-18
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ jobs:
2020
- name: Install dependencies
2121
run: |
2222
python -m pip install --upgrade pip
23-
python -m pip install -r requirements.txt
24-
python -m pip install requests-toolbelt ruff mypy typing_extensions types-networkx types-requests pytest networkx
23+
python -m pip install -r requirements-test.txt
2524
- name: Lint with Ruff
2625
run: ruff check --output-format=github .
2726
- run: |

examples/run_workflow.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Helper script to dispatch workflows."""
22
from argparse import ArgumentParser, FileType
33
from metafold import MetafoldClient
4+
from pathlib import Path
5+
from pprint import pprint
46
import json
57
import os
68
import sys
@@ -17,7 +19,7 @@ def main():
1719

1820
parser.add_argument(
1921
"--asset-uploads", nargs="*",
20-
type=FileType("rb"), help="assets to upload before dispatch")
22+
type=Path, help="assets to upload before dispatch")
2123

2224
project_id = os.environ.get("METAFOLD_PROJECT_ID")
2325
parser.add_argument("-p", "--project-id", default=project_id)
@@ -60,19 +62,24 @@ def main():
6062

6163
if args.asset_uploads:
6264
print("Uploading assets…")
63-
for f in args.asset_uploads:
64-
m.assets.create(f)
65+
for p in args.asset_uploads:
66+
m.assets.create(p.resolve())
6567

6668
print("Running workflow…")
6769
definition = args.workflow.read()
6870
w = m.workflows.run(definition, assets=assets, parameters=params)
6971

7072
print(f"Workflow completed: {w.state}")
7173

72-
if w.state == "failure":
73-
for job_id in w.jobs:
74-
j = m.jobs.get(job_id)
75-
if j.state == "failure":
74+
for job_id in w.jobs:
75+
j = m.jobs.get(job_id)
76+
match j.state:
77+
case "success":
78+
if j.outputs and j.outputs.assets:
79+
pprint(j.outputs.assets)
80+
if j.outputs and j.outputs.params:
81+
pprint(j.outputs.params)
82+
case "failure":
7683
print(f"Job {j.id} failed: {j.error}")
7784

7885
return 0

metafold/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def optional(f: Callable[[T], U]) -> Callable[[Optional[T]], Optional[U]]:
3535
@wraps(f)
3636
def decorator(v: Optional[T]) -> Optional[U]:
3737
if v is None:
38-
return v
38+
return None
3939
return f(v)
4040

4141
return decorator

metafold/func.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def __call__(self, eval_: Evaluator) -> TypedResult[Literal[FuncType.FLOAT]]:
106106
class ComputeCurvatures_Parameters(TypedDict, total=False):
107107
spacing_type: ComputeCurvatures_Enum_spacing_type
108108
step_size: float
109+
volume_size: Vec3f
109110

110111

111112
class ComputeCurvatures(TypedFunc[Literal[FuncType.VEC3F]]):

metafold/jobs.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from metafold.client import Client
66
from metafold.exceptions import PollTimeout
77
from requests import Response
8-
from typing import Any, Optional, TypeAlias, TypedDict, Union
8+
from typing import Any, Optional, TypedDict, Union
9+
from typing_extensions import TypeAlias
910

1011

1112
def _assets(v: list[Union[dict[str, Any], Asset]]) -> list[Asset]:
@@ -37,7 +38,7 @@ class IO:
3738
"""
3839
params: Optional[dict[str, Any]] = None
3940
assets: Optional[dict[str, Asset]] = field(
40-
converter=optional(_assets_dict), default=None)
41+
converter=lambda v: optional(_assets_dict)(v), default=None)
4142

4243
@staticmethod
4344
def from_dict(d: IODict) -> "IO":
@@ -69,13 +70,16 @@ class Job:
6970
type: str
7071
state: str
7172
created: datetime = field(converter=asdatetime)
72-
started: Optional[datetime] = field(converter=optional_datetime, default=None)
73-
finished: Optional[datetime] = field(converter=optional_datetime, default=None)
73+
started: Optional[datetime] = field(
74+
converter=lambda v: optional_datetime(v), default=None)
75+
finished: Optional[datetime] = field(
76+
converter=lambda v: optional_datetime(v), default=None)
7477
error: Optional[str] = None
7578
inputs: IO = field(converter=lambda v: v if isinstance(v, IO) else IO.from_dict(v))
7679
outputs: IO = field(converter=lambda v: v if isinstance(v, IO) else IO.from_dict(v))
7780
# NOTE(ryan): Deprecated
78-
assets: Optional[list[Asset]] = field(converter=optional(_assets), default=None)
81+
assets: Optional[list[Asset]] = field(
82+
converter=lambda v: optional(_assets)(v), default=None)
7983
parameters: dict[str, Any]
8084
meta: dict[str, Any]
8185

metafold/workflows.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ class Workflow:
2424
jobs: list[str] = field(factory=list)
2525
state: str
2626
created: datetime = field(converter=asdatetime)
27-
started: Optional[datetime] = field(converter=optional_datetime, default=None)
28-
finished: Optional[datetime] = field(converter=optional_datetime, default=None)
27+
started: Optional[datetime] = field(
28+
converter=lambda v: optional_datetime(v), default=None)
29+
finished: Optional[datetime] = field(
30+
converter=lambda v: optional_datetime(v), default=None)
2931
definition: str
3032

3133

@@ -101,7 +103,7 @@ def run(
101103
r: Response = self._client.post(f"/projects/{project_id}/workflows", json=payload)
102104
url = r.json()["link"]
103105
try:
104-
r: Response = self._client.poll(url, timeout)
106+
r = self._client.poll(url, timeout)
105107
except PollTimeout as e:
106108
raise RuntimeError(
107109
f"Workflow failed to complete within {timeout} seconds"

requirements-test.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
attrs~=23.2
2+
auth0-python~=4.7
3+
numpy~=1.23
4+
requests~=2.31
5+
scipy~=1.11
6+
7+
mypy
8+
networkx
9+
pytest
10+
requests-toolbelt
11+
ruff
12+
types-networkx
13+
types-requests
14+
typing_extensions

0 commit comments

Comments
 (0)