From 2cebfaff5869543cfb7a4aae97d5257d6478c055 Mon Sep 17 00:00:00 2001 From: Evan Liu Date: Wed, 29 Jan 2025 20:24:32 -0500 Subject: [PATCH] type hinting for model_utils --- src/utils/model_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py index 97c639eb..062445e5 100644 --- a/src/utils/model_utils.py +++ b/src/utils/model_utils.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Any, Tuple, List +from typing import Any, Tuple, List, Optional import torch from torch import Tensor import torch.nn as nn @@ -267,12 +267,12 @@ def train_classification( def train_classification_malicious( self, model: nn.Module, - optim, - dloader, - loss_fn, + optim: torch.optim.Optimizer, + dloader: DataLoader[Any], + loss_fn: Any, device: torch.device, - test_loader=None, - **kwargs, + test_loader: Optional[DataLoader[Any]]=None, + **kwargs: Any, ) -> Tuple[float, float]: correct = 0 train_loss = 0