@@ -1,30 +1,9 @@
-import gc
 from typing import List
 from datasets import Dataset
 from vllm import LLM, SamplingParams
-from utils import generate_prompt
+from utils import generate_prompt, cleanup
 
 
-def cleanup(model):
-    try:
-        import torch
-        import contextlib
-        if torch.cuda.is_available():
-            from vllm.distributed.parallel_state import (
-                destroy_model_parallel, destroy_distributed_environment
-            )
-            destroy_model_parallel()
-            destroy_distributed_environment()
-            del model.llm_engine.model_executor
-            del model
-            with contextlib.suppress(AssertionError):
-                torch.distributed.destroy_process_group()
-            gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-    except ImportError:
-        del model
-
 def generate_predictions(
     model_name: str, dataset: Dataset, temperature: float = 1.0, n: int = 1
 ) -> List[List[str]]:
@@ -62,8 +41,5 @@ def generate_predictions( |
     for output in outputs:
         generated_texts = [one.text for one in output.outputs]
         results.append(generated_texts)
-    cleanup(llm)
+    cleanup(llm, vllm=True)
     return results
-    # out_name = dataset_name.split("/")[-1]
-    # out_name = f"wentingzhao/{out_name}_predictions_{n}"
-    # ds.push_to_hub(out_name)
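
The body of the deleted helper suggests what the relocated utils.cleanup plausibly looks like. Below is a minimal sketch, assuming the new vllm keyword simply gates the vLLM/CUDA teardown path that previously lived in this file; the exact signature and defaults in utils.py are not shown in this diff and are an assumption.

# Hypothetical sketch of utils.cleanup -- not the actual utils.py source.
import contextlib
import gc


def cleanup(model, vllm: bool = False):
    """Release GPU memory held by `model`; tear down vLLM state when vllm=True."""
    try:
        import torch

        if vllm and torch.cuda.is_available():
            # Tear down vLLM's model-parallel and distributed state before
            # dropping the engine, so its CUDA allocations can be reclaimed.
            from vllm.distributed.parallel_state import (
                destroy_model_parallel,
                destroy_distributed_environment,
            )

            destroy_model_parallel()
            destroy_distributed_environment()
            del model.llm_engine.model_executor
            with contextlib.suppress(AssertionError):
                torch.distributed.destroy_process_group()
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
    except ImportError:
        # torch is unavailable; just drop the local reference.
        del model

Deferring the torch import to call time keeps utils importable on machines without a GPU stack, which is presumably why the original function caught ImportError rather than importing torch at module scope.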