diff --git a/gptcli/cli.py b/gptcli/cli.py index 62f10e5..b1d1034 100644 --- a/gptcli/cli.py +++ b/gptcli/cli.py @@ -24,6 +24,7 @@ TERMINAL_WELCOME = """ Hi! I'm here to help. Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C and Enter to clear the conversation, `:r` or Ctrl-R to re-generate the last response. +Type `:b X` where X is a message pair number to go back to that message pair. To enter multi-line mode, enter a backslash `\\` followed by a new line. Exit the multi-line mode by pressing ESC and then Enter (Meta+Enter). Try `:?` for help. @@ -101,6 +102,9 @@ def on_chat_rerun(self, success: bool): else: self.console.print("[bold]Nothing to re-run.[/bold]") + def on_chat_back(self, x: int): + self.console.print(f"[bold]Going back to message pair {x}.[/bold]") + def on_error(self, e: Exception): if isinstance(e, BadRequestError): self.console.print( diff --git a/gptcli/composite.py b/gptcli/composite.py index fde031f..dffc5d0 100644 --- a/gptcli/composite.py +++ b/gptcli/composite.py @@ -39,6 +39,10 @@ def on_chat_rerun(self, success: bool): for listener in self.listeners: listener.on_chat_rerun(success) + def on_chat_back(self, x: int): + for listener in self.listeners: + listener.on_chat_back(x) + def on_error(self, e: Exception): for listener in self.listeners: listener.on_error(e) diff --git a/gptcli/cost.py b/gptcli/cost.py index 7b527c1..540ca6b 100644 --- a/gptcli/cost.py +++ b/gptcli/cost.py @@ -30,6 +30,7 @@ def on_chat_response( model = self.assistant._param("model") num_tokens = usage.total_tokens cost = usage.cost + message_idx = len(messages) // 2 if cost is None: self.logger.error(f"Cannot get cost information for model {model}") @@ -40,7 +41,7 @@ def on_chat_response( self.logger.info(f"Message price (model: {model}): ${cost:.3f}") self.logger.info(f"Current spend: ${self.current_spend:.3f}") self.console.print( - f"Tokens: {num_tokens} | Price: ${cost:.3f} | Total: ${self.current_spend:.3f}", + f"Message pair #: {message_idx} | Tokens: {num_tokens} | Price: ${cost:.3f} | Total: ${self.current_spend:.3f}", justify="right", style="dim", ) diff --git a/gptcli/logging_utils.py b/gptcli/logging_utils.py index b5ac84c..700cf02 100644 --- a/gptcli/logging_utils.py +++ b/gptcli/logging_utils.py @@ -17,6 +17,9 @@ def on_chat_rerun(self, success: bool): if success: self.logger.info("Re-generating the last message.") + def on_chat_back(self, x: int): + self.logger.info(f"Going back to message pair {x}.") + def on_error(self, e: Exception): self.logger.exception(e) diff --git a/gptcli/session.py b/gptcli/session.py index 61f8ff7..0f99e58 100644 --- a/gptcli/session.py +++ b/gptcli/session.py @@ -1,3 +1,4 @@ +import re from abc import abstractmethod from gptcli.assistant import Assistant from gptcli.completion import ( @@ -30,6 +31,9 @@ def on_chat_clear(self): def on_chat_rerun(self, success: bool): pass + def on_chat_back(self, x: int): + pass + def on_error(self, error: Exception): pass @@ -63,12 +67,14 @@ def __init__(self, message: str): COMMAND_QUIT = (":quit", ":q") COMMAND_RERUN = (":rerun", ":r") COMMAND_HELP = (":help", ":h", ":?") +COMMAND_BACK_REGEX = r":b(?:ack)? (\d{1,3})" ALL_COMMANDS = [*COMMAND_CLEAR, *COMMAND_QUIT, *COMMAND_RERUN, *COMMAND_HELP] COMMANDS_HELP = """ Commands: - `:clear` / `:c` / Ctrl+C - Clear the conversation. - `:quit` / `:q` / Ctrl+D - Quit the program. - `:rerun` / `:r` / Ctrl+R - Re-run the last message. +- `:back X` / `:b X` - Go back to message X. Does not re-run assistant's response. - `:help` / `:h` / `:?` - Show this help message. """ @@ -144,9 +150,20 @@ def _add_user_message(self, user_input: str): self.listener.on_chat_message(user_message) self.user_prompts.append(user_message) - def _rollback_user_message(self): - self.messages = self.messages[:-1] - self.user_prompts = self.user_prompts[:-1] + def _back(self, x: int): + """ + Go back to user-assistant message pair x in the conversation. Following messages will be discarded. + """ + self._rollback_user_message(x) + self.listener.on_chat_back(x) + + def _rollback_user_message(self, x: int = None): + if x is None: + self.messages = self.messages[:-1] + self.user_prompts = self.user_prompts[:-1] + else: + self.messages = self.messages[:2*x + 1] + self.user_prompts = self.user_prompts[:2*x + 1] def _print_help(self): with self.listener.response_streamer() as stream: @@ -167,6 +184,10 @@ def process_input(self, user_input: str): elif user_input in COMMAND_HELP: self._print_help() return True + elif re.match(COMMAND_BACK_REGEX, user_input): + match = re.match(COMMAND_BACK_REGEX, user_input) + self._back(int(match.group(1))) + return True self._add_user_message(user_input) response_saved = self._respond()