Commit 25efe76

davidberenstein1957 authored and begumcig committed
fix: update model card tags to include 'pruna-ai' for improved categorization (#334)
1 parent 13088b3 commit 25efe76

4 files changed: 48 additions, 13 deletions

src/pruna/engine/save.py

Lines changed: 1 addition & 1 deletion
@@ -170,7 +170,7 @@ def save_pruna_model_to_hub(
     template_path = Path(__file__).parent / "hf_hub_utils" / "model_card_template.md"
     # Get the pruna library version from initalized module as OSS or paid so we can use the same method for both
     pruna_library = instance.__module__.split(".")[0] if "." in instance.__module__ else None
-    model_card_data["tags"] = [f"{pruna_library}-ai", "safetensors"]
+    model_card_data["tags"] = list({f"{pruna_library}-ai", "safetensors", "pruna-ai"})
     # Build the template parameters dictionary for clarity and maintainability
     template_params: dict = {
         "repo_id": repo_id,

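Why the set literal: building the tags from a set and converting back to a list deduplicates the extra "pruna-ai" entry when the model comes from the OSS package, whose derived tag is already "pruna-ai". A minimal sketch of that behaviour (the value of pruna_library below is an assumed example, not taken from the commit):

# Assumed OSS case: instance.__module__ starts with "pruna", so the derived tag
# duplicates the hard-coded "pruna-ai" and the set collapses it to a single entry.
pruna_library = "pruna"
tags = list({f"{pruna_library}-ai", "safetensors", "pruna-ai"})
assert sorted(tags) == ["pruna-ai", "safetensors"]
# For a package with a different module name, the same expression simply adds
# "pruna-ai" alongside the package-specific tag.
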
src/pruna/evaluation/evaluation_agent.py

Lines changed: 2 additions & 6 deletions
@@ -108,7 +108,6 @@ def evaluate(self, model: Any) -> List[MetricResult]:
         pruna_logger.info("Evaluating isolated inference metrics.")
         results.extend(self.compute_stateless_metrics(model, stateless_metrics))
 
-        model.move_to_device("cpu")
         safe_memory_cleanup()
         if self.evaluation_for_first_model:
             self.first_model_results = results
@@ -154,7 +153,8 @@ def prepare_model(self, model: Any) -> PrunaModel:
             pruna_logger.info("Evaluating a base model.")
             is_base = True
 
-        model.inference_handler.log_model_info()
+        if hasattr(model, "inference_handler"):  # Distributers do not have an inference handler
+            model.inference_handler.log_model_info()
         if (
             "batch_size" in self.task.datamodule.dataloader_args
             and self.task.datamodule.dataloader_args["batch_size"] != model.smash_config.batch_size
@@ -169,9 +169,6 @@ def prepare_model(self, model: Any) -> PrunaModel:
                 model.smash_config.batch_size,
             )
 
-        # ensure the model is on the cpu
-        model.move_to_device("cpu")
-
         return model
 
     def update_stateful_metrics(
@@ -199,7 +196,6 @@ def update_stateful_metrics(
         if not single_stateful_metrics and not pairwise_metrics:
            return
 
-        model.move_to_device(self.device)
        for batch_idx, batch in enumerate(self.task.dataloader):
            processed_outputs = model.run_inference(batch, self.device)
 
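The removed move_to_device("cpu") calls leave device placement to the inference code, and the new hasattr guard lets prepare_model accept wrappers that expose no inference_handler. A minimal sketch of the guard pattern, using a hypothetical stand-in class rather than the real distributed wrapper:

# Hypothetical stand-in for a distributed wrapper without an inference handler.
class DistributedStub:
    pass

def log_model_info_if_available(model) -> None:
    # Mirrors the guarded call above: only objects exposing an inference_handler
    # are asked to log model info; others are skipped instead of raising AttributeError.
    if hasattr(model, "inference_handler"):
        model.inference_handler.log_model_info()

log_model_info_if_available(DistributedStub())  # no-op, no AttributeError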

src/pruna/evaluation/metrics/metric_elapsed_time.py

Lines changed: 1 addition & 5 deletions
@@ -165,11 +165,7 @@ def compute(self, model: PrunaModel, dataloader: DataLoader) -> Dict[str, Any] |
             model,
             dataloader,
             self.n_warmup_iterations,
-            lambda m, x: (
-                m(**x, **m.inference_handler.model_args)  # x is a dict
-                if isinstance(x, dict)
-                else m(x, **m.inference_handler.model_args)  # x is tensor/list
-            ),
+            lambda m, x: (m.run_inference(x)),
         )
 
         # Measurement
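The warmup callable now defers to model.run_inference(x) instead of re-implementing the dict-versus-tensor branch inline. A toy sketch of that dispatch (ToyModel is illustrative only, not the PrunaModel API; it just mirrors what the removed isinstance branch did):

# Toy model whose run_inference unpacks dict batches as keyword arguments and
# passes anything else positionally, like the removed inline lambda did.
class ToyModel:
    def __call__(self, *args, **kwargs):
        return ("called", args, kwargs)

    def run_inference(self, batch):
        if isinstance(batch, dict):
            return self(**batch)
        return self(batch)

warmup = lambda m, x: m.run_inference(x)  # the new, model-agnostic warmup callable
model = ToyModel()
print(warmup(model, {"input_ids": [1, 2, 3]}))  # dict batch -> keyword arguments
print(warmup(model, [1, 2, 3]))                 # list/tensor batch -> positional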

tests/common.py

Lines changed: 44 additions & 1 deletion
@@ -1,6 +1,8 @@
+import ast
 import importlib.util
 import inspect
 import subprocess
+import textwrap
 from pathlib import Path
 from typing import Any, Callable
 
@@ -85,7 +87,11 @@ def run_full_integration(
     smashed_model = algorithm_tester.execute_smash(model, smash_config)
     algorithm_tester.execute_save(smashed_model)
     safe_memory_cleanup()
-    reloaded_model = algorithm_tester.execute_load()
+    reloaded_model = (
+        smashed_model
+        if is_function_unimplemented(algorithm_tester.execute_load)
+        else algorithm_tester.execute_load()
+    )  # noqa: E501
     if device != "accelerate" and not skip_evaluation:
         algorithm_tester.execute_evaluation(reloaded_model, smash_config.data, smash_config["device"])
     if hasattr(reloaded_model, "destroy"):
@@ -296,3 +302,40 @@ def extract_code_blocks_from_node(node: Any, section_name: str) -> None:
         extract_code_blocks_from_node(sec, section_title)
 
     print(f"Code blocks extracted and written to {output_dir}")
+
+
+def is_function_unimplemented(func):
+    """Check if a function is unimplemented."""
+    source = inspect.getsource(func)
+    source = textwrap.dedent(source)
+    tree = ast.parse(source)
+
+    func_def = tree.body[0]
+    if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
+        return False
+
+    # remove docstring if present
+    body = func_def.body
+    if (
+        body
+        and isinstance(body[0], ast.Expr)
+        and isinstance(body[0].value, ast.Constant)
+        and isinstance(body[0].value.value, str)
+    ):
+        body = body[1:]
+
+    if len(body) == 1:
+        stmt = body[0]
+        if isinstance(stmt, ast.Pass):  # pass is not implemented
+            return True
+        # ... is not implemented
+        if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant) and stmt.value.value == Ellipsis:
+            return True
+        if (
+            isinstance(stmt, ast.Raise)
+            and isinstance(stmt.exc, ast.Call)
+            and getattr(stmt.exc.func, "id", "") == "NotImplementedError"
+        ):  # noqa: E501
+            return True
+
+    return False
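A usage sketch for the new helper (the import path is assumed; the expected results follow directly from the AST checks above):

from tests.common import is_function_unimplemented  # assumed import path

def stub_pass():
    """Placeholder that only passes."""
    pass

def stub_ellipsis(): ...

def stub_raises():
    raise NotImplementedError()  # the bare `raise NotImplementedError` form would not match the ast.Call check

def real_load():
    return "model"

assert is_function_unimplemented(stub_pass)
assert is_function_unimplemented(stub_ellipsis)
assert is_function_unimplemented(stub_raises)
assert not is_function_unimplemented(real_load)
# Note: inspect.getsource needs the function's source file to be available, so this
# works for module-level functions but not for code defined interactively.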
