diff --git a/requirements.txt b/requirements.txt
index bf9f390..8d506bc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,4 @@ fire
 loguru
 datasets
 typer
-rich
+rich
\ No newline at end of file
diff --git a/tests.py b/tests.py
index 754f161..e603ea9 100644
--- a/tests.py
+++ b/tests.py
@@ -3,18 +3,26 @@
     generate_openai,
     inject_references_to_messages,
     generate_with_references,
+    generate_hf,
 )
 
-
 if __name__ == "__main__":
 
-    #####
+    ####
     messages = [{"role": "user", "content": "hello!"}]
-    output = generate_together(
-        "meta-llama/Llama-3-8b-chat-hf",
-        messages,
-        temperature=0,
-    )
+    # Prefer the Together API; fall back to local HF inference if it is unavailable.
+    try:
+        output = generate_together(
+            "meta-llama/Llama-3-8b-chat-hf",
+            messages,
+            temperature=0,
+        )
+    except Exception:
+        output = generate_hf(
+            "meta-llama/Meta-Llama-3-8B-Instruct",
+            messages,
+            temperature=0,
+        )
     assert (
         output.strip()
         == "Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?"
@@ -39,11 +47,18 @@
     )
     assert len(messages) == 2
     assert messages[0]["role"] == "system"
-    output = generate_together(
-        "meta-llama/Llama-3-8b-chat-hf",
-        messages,
-        temperature=0,
-    )
+    try:
+        output = generate_together(
+            "meta-llama/Llama-3-8b-chat-hf",
+            messages,
+            temperature=0,
+        )
+    except Exception:
+        output = generate_hf(
+            "meta-llama/Meta-Llama-3-8B-Instruct",
+            messages,
+            temperature=0,
+        )
     assert (
         output.strip()
         == "Hello! It seems like you're looking for assistance with something. I'm here to help! Could you please provide more context or clarify what's on your mind? I'll do my best to offer a helpful and accurate response."
diff --git a/utils.py b/utils.py
index 4651745..9639204 100644
--- a/utils.py
+++ b/utils.py
@@ -4,6 +4,12 @@ import requests
 import openai
 import copy
+# transformers and torch are optional; they are only needed for generate_hf
+try:
+    import transformers
+    import torch
+except ImportError:
+    pass
 
 from loguru import logger
 
 
@@ -133,6 +139,46 @@ def generate_openai(
 
     return output
 
 
+def generate_hf(
+    model,
+    messages,
+    max_tokens=2048,
+    temperature=0.7,
+):
+
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model)
+
+    # Quantize the model to 4-bit (NF4) so it fits in less GPU memory.
+    bnb_config = transformers.BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    )
+
+    pipeline = transformers.pipeline(
+        "text-generation",
+        model=model,
+        device_map="auto",
+        tokenizer=tokenizer,
+        model_kwargs={
+            "torch_dtype": torch.float16,  # dtype for the non-quantized modules
+            "quantization_config": bnb_config,
+            "low_cpu_mem_usage": True,
+        },
+    )
+
+    # Greedy decoding when temperature is 0, sampling otherwise.
+    output = pipeline(
+        messages,
+        max_new_tokens=max_tokens,
+        do_sample=(temperature > 0),
+        temperature=temperature,
+    )
+
+    # The pipeline returns the whole chat; the last message is the model's reply.
+    return output[0]["generated_text"][-1]["content"]
+
+
 def inject_references_to_messages(
     messages,
@@ -178,4 +224,4 @@ def generate_with_references(
         messages=messages,
         temperature=temperature,
         max_tokens=max_tokens,
-    )
+    )
\ No newline at end of file
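
Usage sketch (not part of the patch itself): with this diff applied, callers get a local fallback when the Together API is unreachable. This is a minimal example, assuming transformers, torch, and bitsandbytes are installed and a CUDA-capable GPU is available; the model IDs mirror the ones used in tests.py.

    from utils import generate_together, generate_hf

    messages = [{"role": "user", "content": "hello!"}]
    try:
        # Served model name on the Together API
        output = generate_together("meta-llama/Llama-3-8b-chat-hf", messages, temperature=0)
    except Exception:
        # Local 4-bit HF fallback added by this patch
        output = generate_hf("meta-llama/Meta-Llama-3-8B-Instruct", messages, temperature=0)
    print(output)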