From 83e5dee50f11a4019eb9ae823655514159caadd4 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Tue, 12 Mar 2024 19:44:10 +0100 Subject: [PATCH] added new models --- dl_bench/bench/cnn.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/dl_bench/bench/cnn.py b/dl_bench/bench/cnn.py index 4504c57..0387292 100644 --- a/dl_bench/bench/cnn.py +++ b/dl_bench/bench/cnn.py @@ -12,6 +12,12 @@ def get_cnn(name): efficientnet_v2_m, mobilenet_v3_large, ) + from torchvision.models.segmentation import ( + fcn_resnet50, + lraspp_mobilenet_v3_large, + deeplabv3_resnet50, + ) + from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2 name2model = { "vgg16": vgg16, @@ -21,6 +27,12 @@ def get_cnn(name): "resnext101": resnext101_32x8d, "densenet121": densenet121, "mobilenet_v3l": mobilenet_v3_large, + # Segm + "fcn_resnet50": fcn_resnet50, + "lraspp_mobilenet_v3_large": lraspp_mobilenet_v3_large, + "deeplabv3_resnet50": deeplabv3_resnet50, + # Detection + "retinanet_resnet50_fpn_v2": retinanet_resnet50_fpn_v2, } if name in name2model: return name2model[name]()