Skip to content

Commit 1a9acf5

Browse files
committed
add kwargs in predict and predict_proba
1 parent 3f22697 commit 1a9acf5

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

src/autogluon/cloud/backend/sagemaker_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,7 @@ def predict(
704704
instance_count: int = 1,
705705
custom_image_uri: Optional[str] = None,
706706
wait: bool = True,
707+
inference_kwargs: Optional[Dict[str, Any]] = None,
707708
download: bool = True,
708709
persist: bool = True,
709710
save_path: Optional[str] = None,
@@ -783,6 +784,7 @@ def predict(
783784
instance_count=instance_count,
784785
custom_image_uri=custom_image_uri,
785786
wait=wait,
787+
inference_kwargs=inference_kwargs,
786788
download=download,
787789
persist=persist,
788790
save_path=save_path,
@@ -805,6 +807,7 @@ def predict_proba(
805807
instance_count: int = 1,
806808
custom_image_uri: Optional[str] = None,
807809
wait: bool = True,
810+
inference_kwargs: Optional[Dict[str, Any]] = None,
808811
download: bool = True,
809812
persist: bool = True,
810813
save_path: Optional[str] = None,
@@ -889,6 +892,7 @@ def predict_proba(
889892
instance_count=instance_count,
890893
custom_image_uri=custom_image_uri,
891894
wait=wait,
895+
inference_kwargs=inference_kwargs,
892896
download=download,
893897
persist=persist,
894898
save_path=save_path,
@@ -1133,6 +1137,7 @@ def _predict(
11331137
instance_count=1,
11341138
custom_image_uri=None,
11351139
wait=True,
1140+
inference_kwargs=None,
11361141
download=True,
11371142
persist=True,
11381143
save_path=None,
@@ -1256,6 +1261,7 @@ def _predict(
12561261
transformer_kwargs=transformer_kwargs,
12571262
model_kwargs=model_kwargs,
12581263
repack_model=repack_model,
1264+
inference_kwargs=inference_kwargs,
12591265
**transform_kwargs,
12601266
)
12611267
self._batch_transform_jobs[job_name] = batch_transform_job

src/autogluon/cloud/job/sagemaker_job.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from abc import abstractmethod
33
from typing import Optional
4-
4+
import json
55
import sagemaker
66

77
from ..utils.ag_sagemaker import (
@@ -257,6 +257,7 @@ def run(
257257
model_kwargs,
258258
transformer_kwargs,
259259
repack_model=False,
260+
inference_kwargs=None,
260261
**kwargs,
261262
):
262263
self._local_mode = instance_type in (LOCAL_MODE, LOCAL_MODE_GPU)
@@ -265,6 +266,10 @@ def run(
265266
else:
266267
model_cls = AutoGluonNonRepackInferenceModel
267268
logger.log(20, "Creating inference model...")
269+
inference_kwargs_str = json.dumps(inference_kwargs) if inference_kwargs is not None else None
270+
env = {}
271+
if len(inference_kwargs_str) > 0:
272+
env["inference_kwargs"] = inference_kwargs_str
268273
model = model_cls(
269274
model_data=model_data,
270275
role=role,
@@ -275,6 +280,7 @@ def run(
275280
custom_image_uri=custom_image_uri,
276281
entry_point=entry_point,
277282
predictor_cls=predictor_cls,
283+
env=env,
278284
**model_kwargs,
279285
)
280286
logger.log(20, "Inference model created successfully")

src/autogluon/cloud/predictor/cloud_predictor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,7 @@ def predict(
556556
custom_image_uri: Optional[str] = None,
557557
wait: bool = True,
558558
backend_kwargs: Optional[Dict] = None,
559+
**kwargs,
559560
) -> Optional[pd.Series]:
560561
"""
561562
Batch inference.
@@ -632,6 +633,7 @@ def predict(
632633
instance_count=instance_count,
633634
custom_image_uri=custom_image_uri,
634635
wait=wait,
636+
inference_kwargs=kwargs,
635637
**backend_kwargs,
636638
)
637639

@@ -648,6 +650,7 @@ def predict_proba(
648650
custom_image_uri: Optional[str] = None,
649651
wait: bool = True,
650652
backend_kwargs: Optional[Dict] = None,
653+
**kwargs,
651654
) -> Optional[Union[Tuple[pd.Series, Union[pd.DataFrame, pd.Series]], Union[pd.DataFrame, pd.Series]]]:
652655
"""
653656
Batch inference
@@ -730,6 +733,7 @@ def predict_proba(
730733
instance_count=instance_count,
731734
custom_image_uri=custom_image_uri,
732735
wait=wait,
736+
inference_kwargs=kwargs,
733737
**backend_kwargs,
734738
)
735739

0 commit comments

Comments
 (0)