|
| 1 | +from typing import Any, Optional |
| 2 | + |
| 3 | +from jupyterlab_chat.models import Message |
| 4 | +from litellm import acompletion |
| 5 | + |
| 6 | +from jupyter_ai_persona_manager import BasePersona, PersonaDefaults |
| 7 | +from jupyter_ai_persona_manager.persona_manager import SYSTEM_USERNAME |
| 8 | +from .prompt_template import ( |
| 9 | + JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, |
| 10 | + JupyternautSystemPromptArgs, |
| 11 | +) |
| 12 | + |
| 13 | + |
| 14 | +class JupyternautPersona(BasePersona): |
| 15 | + """ |
| 16 | + The Jupyternaut persona, the main persona provided by Jupyter AI. |
| 17 | + """ |
| 18 | + |
| 19 | + def __init__(self, *args, **kwargs): |
| 20 | + super().__init__(*args, **kwargs) |
| 21 | + |
| 22 | + @property |
| 23 | + def defaults(self): |
| 24 | + return PersonaDefaults( |
| 25 | + name="Jupyternaut", |
| 26 | + avatar_path="/api/jupyternaut/static/jupyternaut.svg", |
| 27 | + description="The standard agent provided by JupyterLab. Currently has no tools.", |
| 28 | + system_prompt="...", |
| 29 | + ) |
| 30 | + |
| 31 | + async def process_message(self, message: Message) -> None: |
| 32 | + if not hasattr(self, 'config_manager'): |
| 33 | + self.send_message( |
| 34 | + "Jupyternaut requires the `jupyter_ai_jupyternaut` server extension package.\n\n", |
| 35 | + "Please make sure to first install that package in your environment & restart the server." |
| 36 | + ) |
| 37 | + if not self.config_manager.chat_model: |
| 38 | + self.send_message( |
| 39 | + "No chat model is configured.\n\n" |
| 40 | + "You must set one first in the Jupyter AI settings, found in 'Settings > AI Settings' from the menu bar." |
| 41 | + ) |
| 42 | + return |
| 43 | + |
| 44 | + model_id = self.config_manager.chat_model |
| 45 | + model_args = self.config_manager.chat_model_args |
| 46 | + context_as_messages = self.get_context_as_messages(model_id, message) |
| 47 | + response_aiter = await acompletion( |
| 48 | + **model_args, |
| 49 | + model=model_id, |
| 50 | + messages=[ |
| 51 | + *context_as_messages, |
| 52 | + { |
| 53 | + "role": "user", |
| 54 | + "content": message.body, |
| 55 | + }, |
| 56 | + ], |
| 57 | + stream=True, |
| 58 | + ) |
| 59 | + |
| 60 | + await self.stream_message(response_aiter) |
| 61 | + |
| 62 | + def get_context_as_messages( |
| 63 | + self, model_id: str, message: Message |
| 64 | + ) -> list[dict[str, Any]]: |
| 65 | + """ |
| 66 | + Returns the current context, including attachments and recent messages, |
| 67 | + as a list of messages accepted by `litellm.acompletion()`. |
| 68 | + """ |
| 69 | + system_msg_args = JupyternautSystemPromptArgs( |
| 70 | + model_id=model_id, |
| 71 | + persona_name=self.name, |
| 72 | + context=self.process_attachments(message), |
| 73 | + ).model_dump() |
| 74 | + |
| 75 | + system_msg = { |
| 76 | + "role": "system", |
| 77 | + "content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args), |
| 78 | + } |
| 79 | + |
| 80 | + context_as_messages = [system_msg, *self._get_history_as_messages()] |
| 81 | + return context_as_messages |
| 82 | + |
| 83 | + def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]: |
| 84 | + """ |
| 85 | + Returns the current history as a list of messages accepted by |
| 86 | + `litellm.acompletion()`. |
| 87 | + """ |
| 88 | + # TODO: consider bounding history based on message size (e.g. total |
| 89 | + # char/token count) instead of message count. |
| 90 | + all_messages = self.ychat.get_messages() |
| 91 | + |
| 92 | + # gather last k * 2 messages and return |
| 93 | + # we exclude the last message since that is the human message just |
| 94 | + # submitted by a user. |
| 95 | + start_idx = 0 if k is None else -2 * k - 1 |
| 96 | + recent_messages: list[Message] = all_messages[start_idx:-1] |
| 97 | + |
| 98 | + history: list[dict[str, Any]] = [] |
| 99 | + for msg in recent_messages: |
| 100 | + role = ( |
| 101 | + "assistant" |
| 102 | + if msg.sender.startswith("jupyter-ai-personas::") |
| 103 | + else "system" if msg.sender == SYSTEM_USERNAME else "user" |
| 104 | + ) |
| 105 | + history.append({"role": role, "content": msg.body}) |
| 106 | + |
| 107 | + return history |
0 commit comments