Skip to content

Commit e824f3d

Browse files
thearchitectorSung96kim
authored andcommitted
fix: import/export typing
1 parent 69fb48a commit e824f3d

File tree

6 files changed

+63
-44
lines changed

6 files changed

+63
-44
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ You will also need the following env variables set for the Exchange integration
105105
2. Activate the virtual environment
106106
`source venv/bin/activate`
107107
3. Install the client
108-
`pip3 install --editable .`
108+
`pip3 install --editable .[all]`
109109
4. Install test deps
110110
`pip3 install "pytest<8" "requests-mock>=1.8.0" "pytest-asyncio>0.21"`
111111
5. Run tests

docker-compose.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
version: "3"
2-
31
services:
42
indico-client-build:
53
build:

indico/queries/model_export.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
from typing import TYPE_CHECKING
2+
13
from indico.client.request import Delay, GraphQLRequest, RequestChain
24
from indico.types.model_export import ModelExport
35

6+
if TYPE_CHECKING: # pragma: no cover
7+
from typing import Any, Iterator, List, Union
8+
9+
from indico.typing import Payload
10+
411

5-
class _CreateModelExport(GraphQLRequest):
12+
class _CreateModelExport(GraphQLRequest["ModelExport"]):
613
query = """
714
mutation ($modelId: Int!) {
815
createModelExport(
@@ -20,11 +27,11 @@ def __init__(self, model_id: int):
2027
self.model_id = model_id
2128
super().__init__(self.query, variables={"modelId": model_id})
2229

23-
def process_response(self, response) -> ModelExport:
24-
return ModelExport(**super().process_response(response)["createModelExport"])
30+
def process_response(self, response: "Payload") -> ModelExport:
31+
return ModelExport(**super().parse_payload(response)["createModelExport"])
2532

2633

27-
class CreateModelExport(RequestChain):
34+
class CreateModelExport(RequestChain["List[ModelExport]"]):
2835
"""
2936
Create a model export.
3037
@@ -36,20 +43,20 @@ class CreateModelExport(RequestChain):
3643
request_interval (int | float): the interval between requests in seconds. Defaults to 5.
3744
"""
3845

39-
previous: ModelExport | None = None
46+
previous: "Any" = None
4047

4148
def __init__(
4249
self,
4350
model_id: int,
4451
wait: bool = True,
45-
request_interval: int | float = 5,
52+
request_interval: "Union[int, float]" = 5,
4653
):
4754
self.wait = wait
4855
self.model_id = model_id
4956
self.request_interval = request_interval
5057
super().__init__()
5158

52-
def requests(self):
59+
def requests(self) -> "Iterator[Union[_CreateModelExport, Delay, GetModelExports]]":
5360
yield _CreateModelExport(self.model_id)
5461
if self.wait:
5562
while self.previous and self.previous.status not in ["COMPLETE", "FAILED"]:
@@ -60,7 +67,7 @@ def requests(self):
6067
yield GetModelExports([self.previous.id], with_signed_url=self.wait is True)
6168

6269

63-
class GetModelExports(GraphQLRequest):
70+
class GetModelExports(GraphQLRequest["List[ModelExport]"]):
6471
"""
6572
Get model export(s).
6673
@@ -91,17 +98,17 @@ class GetModelExports(GraphQLRequest):
9198
"createdBy",
9299
]
93100

94-
def __init__(self, export_ids: list[int], with_signed_url: bool = False):
101+
def __init__(self, export_ids: "List[int]", with_signed_url: bool = False):
95102
if with_signed_url:
96103
self._base_fields.append("signedUrl")
97104

98105
query_with_fields = self.query.replace("{fields}", "\n".join(self._base_fields))
99106
super().__init__(query_with_fields, variables={"exportIds": export_ids})
100107

101-
def process_response(self, response) -> list[ModelExport]:
108+
def process_response(self, response: "Payload") -> "List[ModelExport]":
102109
return [
103110
ModelExport(**export)
104-
for export in super().process_response(response)["modelExports"][
111+
for export in super().parse_payload(response)["modelExports"][
105112
"modelExports"
106113
]
107114
]

indico/queries/model_import.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1-
from typing import Generator
1+
from typing import TYPE_CHECKING, cast
22

33
import requests
44

55
from indico.client.request import GraphQLRequest, RequestChain
66
from indico.errors import IndicoInputError, IndicoRequestError
7-
from indico.queries.jobs import JobStatus
87
from indico.types.jobs import Job
98

9+
from .jobs import JobStatus
1010

11-
class _UploadSMExport(GraphQLRequest):
11+
if TYPE_CHECKING: # pragma: no cover
12+
from typing import Dict, Iterator, Optional, Union # noqa: F401
13+
14+
from indico.typing import Payload
15+
16+
17+
class _UploadSMExport(GraphQLRequest[str]):
1218
query = """
1319
query exportUpload {
1420
exportUpload {
@@ -22,25 +28,26 @@ def __init__(self, file_path: str):
2228
self.file_path = file_path
2329
super().__init__(self.query)
2430

25-
def process_response(self, response) -> str:
26-
resp = super().process_response(response)["exportUpload"]
31+
def process_response(self, response: "Payload") -> str:
32+
resp: "Dict[str, str]" = super().parse_payload(response)["exportUpload"]
2733
signed_url = resp["signedUrl"]
2834
storage_uri = resp["storageUri"]
2935

3036
with open(self.file_path, "rb") as file:
3137
file_content = file.read()
3238

3339
headers = {"Content-Type": "application/zip"}
34-
response = requests.put(signed_url, data=file_content, headers=headers)
40+
export_response = requests.put(signed_url, data=file_content, headers=headers)
3541

36-
if response.status_code != 200:
42+
if export_response.status_code != 200:
3743
raise IndicoRequestError(
38-
f"Failed to upload static model export: {response.text}"
44+
f"Failed to upload static model export: {export_response.text}",
45+
export_response.status_code,
3946
)
4047
return storage_uri
4148

4249

43-
class ProcessStaticModelExport(GraphQLRequest):
50+
class ProcessStaticModelExport(GraphQLRequest["Job"]):
4451
"""
4552
Process a static model export.
4653
@@ -77,12 +84,12 @@ def __init__(
7784
},
7885
)
7986

80-
def process_response(self, response) -> Job:
81-
job_id = super().process_response(response)["processStaticModelExport"]["jobId"]
87+
def process_response(self, response: "Payload") -> Job:
88+
job_id = super().parse_payload(response)["processStaticModelExport"]["jobId"]
8289
return Job(id=job_id)
8390

8491

85-
class UploadStaticModelExport(RequestChain):
92+
class UploadStaticModelExport(RequestChain["Union[Job, str]"]):
8693
"""
8794
Upload a static model export to Indico.
8895
@@ -100,22 +107,27 @@ class UploadStaticModelExport(RequestChain):
100107
"""
101108

102109
def __init__(
103-
self, file_path: str, auto_process: bool = False, workflow_id: int | None = None
110+
self,
111+
file_path: str,
112+
auto_process: bool = False,
113+
workflow_id: "Optional[int]" = None,
104114
):
105-
self.file_path = file_path
106-
self.auto_process = auto_process
107-
if auto_process and not workflow_id:
115+
if auto_process and workflow_id is None:
108116
raise IndicoInputError(
109117
"Must provide `workflow_id` if `auto_process` is True."
110118
)
111119

120+
self.file_path = file_path
121+
self.auto_process = auto_process
112122
self.workflow_id = workflow_id
113123

114-
def requests(self) -> Generator[str | Job, None, None]:
124+
def requests(
125+
self,
126+
) -> "Iterator[Union[_UploadSMExport, ProcessStaticModelExport, JobStatus]]":
115127
if self.auto_process:
116128
yield _UploadSMExport(self.file_path)
117129
yield ProcessStaticModelExport(
118-
storage_uri=self.previous, workflow_id=self.workflow_id
130+
storage_uri=self.previous, workflow_id=cast(int, self.workflow_id)
119131
)
120132
yield JobStatus(self.previous.id)
121133
if self.previous.status == "FAILURE":

indico/queries/workflow_components.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, cast
22

33
import jsons
44

@@ -13,7 +13,7 @@
1313
)
1414

1515
if TYPE_CHECKING: # pragma: no cover
16-
from typing import Iterator, List, Optional, Union
16+
from typing import Any, Iterator, List, Optional, Union
1717

1818
from indico.typing import AnyDict, Payload
1919

@@ -458,7 +458,7 @@ def process_response(self, response: "Payload") -> "Workflow":
458458
)
459459

460460

461-
class AddStaticModelComponent(RequestChain):
461+
class AddStaticModelComponent(RequestChain["Workflow"]):
462462
"""
463463
Add a static model component to a workflow.
464464
@@ -473,13 +473,13 @@ class AddStaticModelComponent(RequestChain):
473473
`export_file(str)`: the path to the static model export file.
474474
"""
475475

476-
previous = None
476+
previous: "Any" = None
477477

478478
def __init__(
479479
self,
480480
workflow_id: int,
481-
after_component_id: int | None = None,
482-
after_component_link_id: int | None = None,
481+
after_component_id: "Optional[int]" = None,
482+
after_component_link_id: "Optional[int]" = None,
483483
static_component_config: "Optional[AnyDict]" = None,
484484
component_name: "Optional[str]" = None,
485485
auto_process: bool = False,
@@ -514,11 +514,13 @@ def __init__(
514514
self.auto_process = auto_process
515515
self.export_file = export_file
516516

517-
def requests(self):
517+
def requests(
518+
self,
519+
) -> "Iterator[Union[UploadStaticModelExport, _AddWorkflowComponent]]":
518520
if self.auto_process:
519521
yield UploadStaticModelExport(
520522
auto_process=True,
521-
file_path=self.export_file,
523+
file_path=cast(str, self.export_file),
522524
workflow_id=self.workflow_id,
523525
)
524526
self.component.update(

tests/integration/queries/test_workflow_component.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
ModelGroup,
2121
ModelTaskType,
2222
NewLabelsetArguments,
23-
StaticModelConfig,
23+
# StaticModelConfig,
2424
)
2525

2626
from ..data.datasets import * # noqa
@@ -257,9 +257,9 @@ def test_add_static_model_component(indico, org_annotate_dataset):
257257
static_model_req = AddStaticModelComponent(
258258
workflow_id=wf.id,
259259
after_component_id=after_component_id,
260-
static_component_config=StaticModelConfig(
261-
export_meta=finished_job.result,
262-
),
260+
static_component_config={
261+
"export_meta": finished_job.result,
262+
},
263263
)
264264
wf = client.call(static_model_req)
265265

0 commit comments

Comments
 (0)