From 8a4bdc5d06693acfe5abd39b3ca261bf8c224a4f Mon Sep 17 00:00:00 2001 From: keetrap Date: Sat, 1 Feb 2025 18:26:56 +0530 Subject: [PATCH] Enable prompt push to hub --- src/smolagents/agents.py | 48 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 7e8016164..d2bf40811 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -211,6 +211,54 @@ def initialize_system_prompt(self): system_prompt = format_prompt_with_managed_agents_descriptions(system_prompt, self.managed_agents) return system_prompt + def prompt_push_to_hub( + self, + repo_id: str, + filename: Optional[str] = "prompt.yaml", + token: Optional[str] = None, + private: Optional[bool] = False, + metadata: Optional[Dict[str, Any]] = None, + ) -> Any: + """ + Push a prompt template to the Hugging Face Hub. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + filename (str, optional): Name of the file to save. Default is "prompt.yaml". Json and yaml extensions are supported. + token (str, optional): Authentication token for the Hugging Face API. + private (bool, optional): Whether to make the repository private. Default is False. + metadata (dict, optional): Metadata associated with the prompt template. + + Returns: + The saved prompt template object. + """ + try: + from prompt_templates import ChatPromptTemplate + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "To push a prompt to the hub, you need to have the 'prompt_templates' library installed. " + "You can install it with 'pip install prompt_templates'." + ) from e + system_prompt = self.initialize_system_prompt() + + user_prompt = self.task if self.task is not None else "No task provided" + + messages_template = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + prompt_template = ChatPromptTemplate(template=messages_template, metadata=metadata) + + try: + prompt_template.save_to_hub( + repo_id=repo_id, filename=filename, token=token, create_repo=True, private=private + ) + except Exception as e: + raise ValueError(f"An error occurred while pushing the template to the hub: {str(e)}") + + return prompt_template + def write_memory_to_messages( self, summary_mode: Optional[bool] = False,