Skip to content

Commit 0552dbe

Browse files
authored
[Fix] Small fixes to account for changes to RegistryMixin (#1558)
1 parent 79abf8d commit 0552dbe

File tree

5 files changed

+6
-61
lines changed

5 files changed

+6
-61
lines changed

src/deepsparse/operators/registry.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,10 @@ class OperatorRegistry(RegistryMixin):
3333
"""
3434

3535
@classmethod
36-
def register_value(cls, operator, name):
36+
def register_value(cls, operator, name, alias):
3737
from deepsparse.operators import Operator
3838

39-
if not isinstance(name, list):
40-
name = [name]
41-
42-
for task_name in name:
43-
register(Operator, operator, task_name, require_subclass=True)
39+
register(Operator, operator, name, alias, require_subclass=True)
4440

4541
return operator
4642

src/deepsparse/transformers/pipelines/code_generation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
__all__ = ["CodeGenerationPipeline"]
2121

2222

23-
@OperatorRegistry.register(name=["code_generation", "code_gen", "codegen"])
23+
@OperatorRegistry.register(name="code_generation", alias=["code_gen", "codegen"])
2424
class CodeGenerationPipeline(TextGenerationPipeline):
2525
"""
2626
Subclass of text generation pipeline to support any defaults or

src/deepsparse/transformers/pipelines/text_generation/pipeline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
_LOGGER = logging.getLogger(__name__)
5555

5656

57-
@OperatorRegistry.register(name=["text_generation", "opt", "mpt", "llama"])
57+
@OperatorRegistry.register(name="text_generation", alias=["opt", "mpt", "llama"])
5858
class TextGenerationPipeline(Pipeline):
5959
DEFAULT_SEQUENCE_LENGTH = 1024
6060

tests/deepsparse/evaluation/test_registry.py

-51
This file was deleted.

tests/test_pipeline_benchmark.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
),
5353
(
5454
"image_classification",
55-
"zoo:cv/classification/resnet_v1-50_2x/pytorch/sparseml/imagenet/base-none",
55+
"zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/base-none",
5656
[
5757
"-c",
5858
"tests/test_data/pipeline_bench_config.json",
@@ -66,7 +66,7 @@
6666
),
6767
(
6868
"image_classification",
69-
"zoo:cv/classification/resnet_v1-50_2x/pytorch/sparseml/imagenet/base-none",
69+
"zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/base-none",
7070
[],
7171
),
7272
(

0 commit comments

Comments
 (0)