diff --git a/perplexity-cli b/perplexity-cli new file mode 100755 index 0000000..688ac98 --- /dev/null +++ b/perplexity-cli @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +# PYTHON_ARGCOMPLETE_OK +import sys + +from perplexity.cli import main + + +raise SystemExit(main(sys.argv)) diff --git a/perplexity/__init__.py b/perplexity/__init__.py index df97250..dc78bd4 100644 --- a/perplexity/__init__.py +++ b/perplexity/__init__.py @@ -2,4 +2,5 @@ from .utils import * from .labs import Labs -from .perplexity import Perplexity \ No newline at end of file +from .perplexity import Perplexity +from .stream import AnswerStreamParser diff --git a/perplexity/cli.py b/perplexity/cli.py new file mode 100644 index 0000000..96dceaf --- /dev/null +++ b/perplexity/cli.py @@ -0,0 +1,159 @@ +import argparse +import sys +from typing import Optional, TextIO + +import argcomplete + +from perplexity.config import configure_mail +from perplexity import AnswerStreamParser, Perplexity + + +class OutputWriter: + def __init__(self, raw: bool = False, stream: Optional[TextIO] = None) -> None: + self.raw = raw + self.stream = stream or sys.stdout + self._renderer = None + self._buffer = "" + self._started = False + + def write(self, text: str) -> None: + if not text: + return + + if self.raw: + self.stream.write(text) + self.stream.flush() + return + + self._buffer += text + last_para = self._buffer.rfind("\n\n") + if last_para >= 0: + to_render = self._buffer[:last_para + 2] + self._buffer = self._buffer[last_para + 2:] + self._render(to_render) + + def close(self) -> None: + if self._buffer: + self._render(self._buffer) + self._buffer = "" + if self._renderer is not None: + self._render_trailing() + self._renderer.tidyup() + self.stream.flush() + + def _render(self, text: str) -> None: + renderer = self._streamdown() + if not self._started: + self._started = True + renderer.render("\n") + if hasattr(renderer, "state"): + renderer.state.list_item_stack = [] + renderer.state.in_list = False + renderer.state.list_indent_text = 0 + renderer.render(text) + + def _render_trailing(self) -> None: + self._streamdown().render("\n") + + def _streamdown(self): + if self._renderer is None: + self._renderer = _load_streamdown() + return self._renderer + + +def _load_streamdown(): + import shutil + import streamdown + import streamdown.sdlib as sdlib + + sd = streamdown.Streamdown() + sd.setup() + _patch_streamdown(sd, sdlib) + return sd + + +def _patch_streamdown(sd, sdlib) -> None: + terminal_cols = shutil.get_terminal_size().columns + sd.state.WidthArg = min(terminal_cols, 100) + sd.width_calc() + + orig_emit_h = sdlib.emit_h + + def emit_h_left(level, text): + from streamdown.sdlib import line_format, text_wrap, BOLD, FG, FGRESET + + if level > 2: + return orig_emit_h(level, text) + text = line_format(text) + res = [] + for line in text_wrap(text): + if level == 1: + res.append(f"{sd.state.space_left()}\n{sd.state.space_left()}{BOLD[0]}{line}{BOLD[1]}\n") + else: + res.append(f"{sd.state.space_left()}\n{sd.state.space_left()}{BOLD[0]}{FG}{sd.Style.Bright}{line}{BOLD[1]}{FGRESET}") + return "\n".join(res) + + sdlib.emit_h = emit_h_left + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Perplexity CLI") + parser.add_argument("-a", "--account", metavar="EMAIL", help="account email for authenticated requests") + parser.add_argument("-r", "--raw", action="store_true", help="print parsed markdown without terminal rendering") + parser.add_argument("-s", "--sources", action="store_true", help="append sources") + parser.add_argument("-p", "--pro", action="store_true", help="use Pro search") + parser.add_argument("prompt", nargs="+", help="search prompt") + return parser + + +def configure(argv: list[str]) -> int: + config_parser = argparse.ArgumentParser( + prog=f"{argv[0]} config", + description="Configure perplexity-cli", + ) + config_subparsers = config_parser.add_subparsers(dest="config_command", required=True) + config_subparsers.add_parser("mail", help="configure IMAP mail login") + config_args = config_parser.parse_args(argv[2:]) + if config_args.config_command == "mail": + configure_mail() + return 0 + + +def run(args: argparse.Namespace) -> int: + perplexity = Perplexity(args.account) + writer = OutputWriter(raw=args.raw) + try: + answer = perplexity.search(" ".join(args.prompt), mode="copilot" if args.pro else "concise") + stream_parser = AnswerStreamParser() + for event in answer: + delta = stream_parser.feed(event) + writer.write(delta) + + if stream_parser.text: + writer.write("\n") + + if args.sources: + sources = stream_parser.format_sources(cited_only=stream_parser.has_citations()) + if sources: + writer.write(sources + "\n") + return 0 + finally: + try: + writer.close() + finally: + perplexity.close() + + +def main(argv: Optional[list[str]] = None) -> int: + argv = argv or sys.argv + try: + if len(argv) > 1 and argv[1] == "config": + return configure(argv) + + parser = build_parser() + argcomplete.autocomplete(parser) + return run(parser.parse_args(argv[1:])) + except KeyboardInterrupt: + sys.stderr.write("\n") + sys.stderr.flush() + return 130 diff --git a/perplexity/config.py b/perplexity/config.py new file mode 100644 index 0000000..942bddf --- /dev/null +++ b/perplexity/config.py @@ -0,0 +1,286 @@ +import re +from datetime import datetime, timezone +from email import message_from_bytes +from email.utils import parsedate_to_datetime +from getpass import getpass +from html import unescape +from imaplib import IMAP4, IMAP4_SSL +from json import JSONDecodeError, dumps, loads +from pathlib import Path +from re import DOTALL, IGNORECASE, search, sub +from sys import stderr +from time import sleep, time +from typing import Callable, Optional + + +def config_dir() -> Path: + path = Path.home() / ".cache" / "perplexity-cli" + path.mkdir(parents=True, exist_ok=True) + return path + + +def config_path() -> Path: + return config_dir() / "config.json" + + +def load_config() -> dict: + try: + return loads(config_path().read_text()) + except (OSError, JSONDecodeError): + return {} + + +def save_config(config: dict) -> None: + path = config_path() + path.write_text(dumps(config, indent=2) + "\n") + try: + path.chmod(0o600) + except OSError: + pass + + +def mail_config_for(account: str) -> Optional[dict]: + accounts = load_config().get("mail", {}).get("accounts", {}) + if account in accounts: + return accounts[account] + for configured_account, mail_config in accounts.items(): + if configured_account.lower() == account.lower(): + return mail_config + return None + +def loop_input(prompt: str, default: Optional[str] = None, validator: Optional[Callable[[str], bool]] = None) -> str: + while True: + value = input(prompt).strip() + if value: + if validator and not validator(value): + continue + return value + if default is not None: + return default + +def loop_password(prompt: str, default: Optional[str] = None) -> str: + while True: + value = getpass(prompt) + if value: + return value + if default is not None: + return default + +def is_valid_port(value: str) -> bool: + if not value.isdigit(): + return False + port = int(value) + is_valid = 1 <= port <= 65535 + if not is_valid: + stderr.write("ERROR: Port number must be between 1 and 65535.\n\n") + return is_valid + +def is_valid_email(value: str) -> bool: + is_valid = True + if " " in value: + is_valid = False + + if is_valid: + m = re.fullmatch(r"([A-Za-z0-9.!#$%&'*+/=?^_`{|}~-]+)@([A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)+)", value) + if not m: + is_valid = False + else: + domain = m.group(2) + labels = domain.split(".") + is_valid = all(label and not label.startswith("-") and not label.endswith("-") for label in labels) + + if not is_valid: + stderr.write("ERROR: Invalid email address.\n\n") + return is_valid + +def is_valid_yes_no(value: str) -> bool: + is_valid = bool(re.fullmatch(r"[yY](es)?|[nN]o?", value)) + if not is_valid: + stderr.write("ERROR: Enter yes or no.\n\n") + return is_valid + +def parse_yes_no(value: str) -> bool: + return bool(re.fullmatch(r"[yY](es)?", value)) + +def configure_mail() -> None: + stderr.write("Configure the IMAP mail account which receives Perplexity Sign-in emails...\n\n") + accounts = load_config().get("mail", {}).get("accounts", {}) + if accounts: + stderr.write("Configured accounts:\n") + for configured_account in sorted(accounts): + stderr.write(f" - {configured_account}\n") + stderr.write("\n") + account = loop_input("Perplexity Account: ", validator=is_valid_email) + existing_mail_config = mail_config_for(account) or {} + stderr.write("\nYou can forward the Perplexity Sign-in emails to a different address,\nif so, enter it here. Otherwise, just press enter.\n\n") + address_default = existing_mail_config.get("address", account) + address = loop_input(f"Email Address [{address_default}]: ", address_default, is_valid_email) + stderr.write("\nEnter the IMAP login details.\n") + username_default = existing_mail_config.get("username") + username_prompt = f"Username [{username_default}]: " if username_default else "Username: " + username = loop_input(username_prompt, username_default) + password_default = existing_mail_config.get("password") + password_prompt = "Password [configured]: " if password_default else "Password: " + password = loop_password(password_prompt, password_default) + imap_hostname_default = existing_mail_config.get("imap_hostname") + imap_hostname_prompt = f"IMAP Server [{imap_hostname_default}]: " if imap_hostname_default else "IMAP Server: " + imap_hostname = loop_input(imap_hostname_prompt, imap_hostname_default) + port_default = str(existing_mail_config.get("port", 993)) + port_text = loop_input(f"Port [{port_default}]: ", port_default, is_valid_port) + port = int(port_text) + folder_default = existing_mail_config.get("folder", "INBOX") + folder = loop_input(f"Folder [{folder_default}]: ", folder_default) + delete_default = existing_mail_config.get("delete_signin_messages", False if existing_mail_config else True) + delete_prompt = "Delete Sign-in Messages [Y/n]? " if delete_default else "Delete Sign-in Messages [y/N]? " + delete_signin_messages_text = loop_input(delete_prompt, "yes" if delete_default else "no", is_valid_yes_no) + delete_signin_messages = parse_yes_no(delete_signin_messages_text) + + mail_config = { + "address": address, + "username": username, + "password": password, + "imap_hostname": imap_hostname, + "port": port, + "folder": folder, + "delete_signin_messages": delete_signin_messages, + } + stderr.write("\nValidating IMAP login details...\n") + validate_mail_config(mail_config) + stderr.write("IMAP login details validated.\n") + + config = load_config() + config.setdefault("mail", {}).setdefault("accounts", {})[account] = mail_config + save_config(config) + stderr.write(f"\nMail login configured for {account}.\n") + + +def validate_mail_config(mail_config: dict) -> None: + port = int(mail_config["port"]) + if port == 993: + mailbox = IMAP4_SSL(mail_config["imap_hostname"], port) + else: + mailbox = IMAP4(mail_config["imap_hostname"], port) + mailbox.starttls() + try: + status, _ = mailbox.login(mail_config["username"], mail_config["password"]) + if status != "OK": + raise RuntimeError("IMAP login failed") + folder = mail_config.get("folder", "INBOX") + status, _ = mailbox.select(folder) + if status != "OK": + raise RuntimeError(f"IMAP folder selection failed: {folder}") + finally: + try: + mailbox.logout() + except OSError: + pass + + +def extract_perplexity_login_url(raw_message: bytes, account: str, now: Optional[datetime] = None) -> Optional[str]: + msg = message_from_bytes(raw_message) + subject = str(msg.get("Subject", "")) + sender = str(msg.get("From", "")) + recipient = str(msg.get("To", "")) + if subject != "Sign in to Perplexity": + return None + if "perplexity" not in sender.lower(): + return None + if account.lower() not in recipient.lower(): + return None + + try: + date = parsedate_to_datetime(msg.get("Date")) + except (TypeError, ValueError): + return None + if date.tzinfo is None: + date = date.replace(tzinfo=timezone.utc) + current = now or datetime.now(timezone.utc) + if (current - date).total_seconds() > 60: + return None + + for body in _message_bodies(msg): + url = _extract_url_from_body(body) + if url: + return url + return None + + +def retrieve_login_url_from_mail(account: str, mail_config: dict, timeout: float = 60) -> str: + deadline = time() + timeout + last_error = None + while time() < deadline: + try: + url = _check_mailbox(account, mail_config) + if url: + return url + except Exception as exc: + last_error = exc + sleep(5) + if last_error: + raise RuntimeError(f"mail lookup failed: {last_error}") + raise RuntimeError("mail lookup timed out") + + +def _check_mailbox(account: str, mail_config: dict) -> Optional[str]: + port = int(mail_config["port"]) + if port == 993: + mailbox = IMAP4_SSL(mail_config["imap_hostname"], port) + else: + mailbox = IMAP4(mail_config["imap_hostname"], port) + mailbox.starttls() + try: + mailbox.login(mail_config["username"], mail_config["password"]) + mailbox.select(mail_config.get("folder", "INBOX")) + status, data = mailbox.search(None, 'FROM "perplexity" SUBJECT "Sign in to Perplexity"') + if status != "OK" or not data or not data[0]: + return None + for message_id in reversed(data[0].split()): + status, message_data = mailbox.fetch(message_id, "(RFC822)") + if status != "OK": + continue + for part in message_data: + if not isinstance(part, tuple): + continue + url = extract_perplexity_login_url(part[1], account) + if url: + if mail_config.get("delete_signin_messages", False): + mailbox.store(message_id, "+FLAGS", "\\Deleted") + mailbox.expunge() + return url + return None + finally: + try: + mailbox.logout() + except OSError: + pass + + +def _message_bodies(msg) -> list[str]: + bodies = [] + parts = msg.walk() if msg.is_multipart() else [msg] + for part in parts: + if part.get_content_maintype() == "multipart": + continue + payload = part.get_payload(decode=True) + if payload is None: + continue + charset = part.get_content_charset() or "utf-8" + bodies.append(payload.decode(charset, errors="replace")) + return bodies + + +def _extract_url_from_body(body: str) -> Optional[str]: + text = unescape(body) + text = sub(r"=\r?\n", "", text) + text = text.replace("=3D", "=") + match = search( + r"https://www\.perplexity\.ai/api/auth/callback/email\?[^<>\s\"']+", + text, + IGNORECASE | DOTALL, + ) + if not match: + return None + url = match.group(0) + url = sub(r"\s+", "", url) + url = url.replace("&", "&") + return url diff --git a/perplexity/perplexity.py b/perplexity/perplexity.py index f90d315..4f8807e 100644 --- a/perplexity/perplexity.py +++ b/perplexity/perplexity.py @@ -1,32 +1,34 @@ from typing import Iterable, Dict +from urllib.parse import quote from os import listdir +from pathlib import Path +from re import sub from uuid import uuid4 from time import sleep, time from threading import Thread -from json import loads, dumps +from json import loads, dumps, JSONDecodeError from random import getrandbits from websocket import WebSocketApp from requests import Session, get, post +from requests.exceptions import RequestException +from .config import config_dir, mail_config_for, retrieve_login_url_from_mail class Perplexity: def __init__(self, email: str = None) -> None: - self.session: Session = Session() self.user_agent: dict = { "User-Agent": "Ask/2.9.1/2406 (iOS; iPhone; Version 17.1) isiOSOnMac/false", "X-Client-Name": "Perplexity-iOS", "X-App-ApiClient": "ios" } - self.session.headers.update(self.user_agent) + self.email: str = email + self._reset_session() - if email and ".perplexity_session" in listdir(): - self._recover_session(email) - else: + recovered_session = False + if email and self._token_path(email).exists(): + recovered_session = self._recover_session(email) + if not recovered_session: self._init_session_without_login() if email: self._login(email) - self.email: str = email - self.t: str = self._get_t() - self.sid: str = self._get_sid() - self.n: int = 1 self.base: int = 420 self.queue: list = [] @@ -35,7 +37,16 @@ def __init__(self, email: str = None) -> None: self.backend_uuid: str = None # unused because we can't yet follow-up questions self.frontend_session_id: str = str(uuid4()) - assert self._ask_anonymous_user(), "failed to ask anonymous user" + if not self._bootstrap_socket_session(): + if email and recovered_session: + self._reset_session() + self._init_session_without_login() + self._login(email) + if not self._bootstrap_socket_session(): + raise RuntimeError("failed to initialize websocket session after re-login") + else: + raise RuntimeError("failed to initialize websocket session") + self.ws: WebSocketApp = self._init_websocket() self.ws_thread: Thread = Thread(target=self.ws.run_forever).start() self._auth_session() @@ -43,28 +54,56 @@ def __init__(self, email: str = None) -> None: while not (self.ws.sock and self.ws.sock.connected): sleep(0.01) - def _recover_session(self, email: str) -> None: - with open(".perplexity_session", "r") as f: - perplexity_session: dict = loads(f.read()) + def _token_path(self, email: str) -> Path: + safe = sub(r"[^a-zA-Z0-9-]", "_", email) + return config_dir() / f"{safe}.token" + + def _reset_session(self) -> None: + self.session: Session = Session() + self.session.headers.update(self.user_agent) + + def _recover_session(self, email: str) -> bool: + try: + cookies = loads(self._token_path(email).read_text()) + except (OSError, JSONDecodeError): + return False + self.session.cookies.update(cookies) + return True + + def _bootstrap_socket_session(self) -> bool: + try: + self.t = self._get_t() + self.sid = self._get_sid() + return self._ask_anonymous_user() + except (IndexError, KeyError, JSONDecodeError, RequestException): + return False + - if email in perplexity_session: - self.session.cookies.update(perplexity_session[email]) - else: - self._login(email, perplexity_session) - def _login(self, email: str, ps: dict = None) -> None: self.session.post(url="https://www.perplexity.ai/api/auth/signin-email", data={"email": email}) - email_link: str = str(input("paste the link you received by email: ")) + import sys + email_link = None + mail_config = mail_config_for(email) + if mail_config: + sys.stderr.write("Retrieving Perplexity login token from email...\n") + sys.stderr.flush() + try: + email_link = retrieve_login_url_from_mail(email, mail_config) + except Exception as exc: + sys.stderr.write(f"Failed to retrieve token from email: {exc}\n") + sys.stderr.flush() + if not email_link: + sys.stderr.write("Token (or link) received via email: ") + sys.stderr.flush() + email_input: str = sys.stdin.readline().strip() + if email_input.startswith("http"): + email_link = email_input + else: + email_link = f"https://www.perplexity.ai/api/auth/callback/email?callbackUrl=defaultMobileSignIn&email={quote(email)}&token={email_input}" self.session.get(email_link) - if ps: - ps[email] = self.session.cookies.get_dict() - else: - ps = {email: self.session.cookies.get_dict()} - - with open(".perplexity_session", "w") as f: - f.write(dumps(ps)) + self._token_path(email).write_text(dumps(self.session.cookies.get_dict())) def _init_session_without_login(self) -> None: self.session.get(url=f"https://www.perplexity.ai/search/{str(uuid4())}") @@ -130,11 +169,12 @@ def on_message(ws: WebSocketApp, message: str) -> None: if message.startswith("42"): message : list = loads(message[2:]) content: dict = message[1] - if "mode" in content and content["mode"] == "copilot": - content["copilot_answer"] = loads(content["text"]) - elif "mode" in content: - content.update(loads(content["text"])) - content.pop("text") + if "text" in content: + if "mode" in content and content["mode"] == "copilot": + content["copilot_answer"] = loads(content["text"]) + elif "mode" in content: + content.update(loads(content["text"])) + content.pop("text", None) if (not ("final" in content and content["final"])) or ("status" in content and content["status"] == "completed"): self.queue.append(content) if message[0] == "query_answered": @@ -303,10 +343,4 @@ def close(self) -> None: self.ws.close() if self.email: - with open(".perplexity_session", "r") as f: - perplexity_session: dict = loads(f.read()) - - perplexity_session[self.email] = self.session.cookies.get_dict() - - with open(".perplexity_session", "w") as f: - f.write(dumps(perplexity_session)) \ No newline at end of file + self._token_path(self.email).write_text(dumps(self.session.cookies.get_dict())) diff --git a/perplexity/stream.py b/perplexity/stream.py new file mode 100644 index 0000000..e00620a --- /dev/null +++ b/perplexity/stream.py @@ -0,0 +1,314 @@ +from re import Match, search, sub +from typing import Any, Dict, Iterable, Optional + + +class AnswerStreamParser: + """Extract live answer text deltas from Perplexity websocket events.""" + + def __init__(self) -> None: + self.text = "" + self.raw_text = "" + self._chunk_count = 0 + self._source_urls: set[str] = set() + self._annotations: list[dict[str, Any]] = [] + self.sources: list[dict[str, str]] = [] + + def feed(self, event: Dict[str, Any]) -> str: + self._collect_sources(event) + state = self._event_state(event) + if state is None: + return "" + raw_text, chunk_count = state + text = self._stable_answer_text(self._format_answer_text(raw_text)) + self.raw_text = raw_text + + if text.startswith(self.text): + delta = text[len(self.text):] + self.text = text + self._chunk_count = chunk_count + return delta + + if self.text.startswith(text): + return "" + + prefix_len = self._common_prefix_len(self.text, text) + if prefix_len == len(self.text): + delta = text[prefix_len:] + self.text = text + self._chunk_count = chunk_count + return delta + + return "" + + def parse(self, events: Iterable[Dict[str, Any]]) -> Iterable[str]: + for event in events: + delta = self.feed(event) + if delta: + yield delta + + def format_answer(self, citations: bool = False) -> str: + if not citations: + return self.text + + markers = self._citation_markers() + if not markers: + return self.text + + text = self.text + for offset, marker in sorted(markers.items(), reverse=True): + if 0 <= offset <= len(text): + text = text[:offset] + marker + text[offset:] + return text + + def has_citations(self) -> bool: + return bool(self._citation_markers()) + + def format_sources(self, cited_only: bool = False) -> str: + sources = self._cited_sources() if cited_only else self.sources + if not sources: + return "" + + lines = ["", "---", "", "## Sources", ""] + for default_index, source in enumerate(sources, start=1): + index = source.get("number", str(default_index)) + name = source.get("name") or source.get("url") or "Source" + url = source.get("url", "") + if url: + lines.append(f"- **[{index}]** [{name}]({url})") + else: + lines.append(f"- **[{index}]** {name}") + return "\n".join(lines) + + def _collect_sources(self, event: Dict[str, Any]) -> None: + blocks = event.get("blocks") + if not isinstance(blocks, list): + return + + for block in blocks: + web_result_block = block.get("web_result_block") + if not isinstance(web_result_block, dict): + continue + web_results = web_result_block.get("web_results") + if not isinstance(web_results, list): + continue + for result in web_results: + if not isinstance(result, dict): + continue + url = result.get("url") + if not isinstance(url, str) or not url or url in self._source_urls: + continue + name = result.get("name") + self._source_urls.add(url) + self.sources.append({ + "name": name if isinstance(name, str) else url, + "url": url, + }) + + def _event_state(self, event: Dict[str, Any]) -> Optional[tuple[str, int]]: + blocks = event.get("blocks") + if not isinstance(blocks, list): + return None + + for usage in ("ask_text_0_markdown", "ask_text"): + for block in blocks: + if block.get("intended_usage") == usage: + markdown_block = block.get("markdown_block") + self._collect_annotations(markdown_block) + state = self._markdown_block_state(markdown_block) + if state is not None: + return state + + for block in blocks: + if isinstance(block.get("markdown_block"), dict): + markdown_block = block["markdown_block"] + self._collect_annotations(markdown_block) + state = self._markdown_block_state(markdown_block) + if state is not None: + return state + + return None + + def _markdown_block_state(self, block: Any) -> Optional[tuple[str, int]]: + if not isinstance(block, dict): + return None + + chunks = block.get("chunks") + answer = block.get("answer") + if isinstance(answer, str): + chunk_count = len(chunks) if isinstance(chunks, list) else self._chunk_count + return answer, chunk_count + + if not isinstance(chunks, list): + return None + + chunk_text = "".join(chunk for chunk in chunks if isinstance(chunk, str)) + offset = block.get("chunk_starting_offset", 0) + if isinstance(offset, int) and offset == self._chunk_count: + return self.raw_text + chunk_text, offset + len(chunks) + + if offset == 0: + return chunk_text, len(chunks) + + return None + + def _common_prefix_len(self, left: str, right: str) -> int: + length = min(len(left), len(right)) + for index in range(length): + if left[index] != right[index]: + return index + return length + + def _format_answer_text(self, text: str) -> str: + lines = text.splitlines(keepends=True) + in_fence = False + formatted_lines = [] + + for line in lines: + if line.lstrip().startswith("```"): + in_fence = not in_fence + formatted_lines.append(line) + elif in_fence: + formatted_lines.append(line) + else: + formatted_lines.append(self._format_inline_citations(line)) + + return "".join(formatted_lines) + + def _format_inline_citations(self, text: str) -> str: + parts = text.split("`") + for index in range(0, len(parts), 2): + parts[index] = self._space_citations(parts[index]) + return "`".join(parts) + + def _space_citations(self, text: str) -> str: + def format_marker(match: Match[str]) -> str: + formatted = f"`{match.group(0)}`" + if match.start() == 0: + return formatted + return " " + formatted + + return sub(r"(? str: + incomplete_marker = search(r"(? None: + if not isinstance(block, dict): + return + + annotations = block.get("inline_token_annotations") + if not isinstance(annotations, list): + return + + self._annotations = [ + annotation for annotation in annotations + if isinstance(annotation, dict) + ] + + def _citation_markers(self) -> dict[int, str]: + markers: dict[int, list[int]] = {} + for annotation in self._annotations: + end = self._annotation_end(annotation) + if end is None: + continue + numbers = self._annotation_source_numbers(annotation) + if not numbers: + continue + markers.setdefault(end, []) + for number in numbers: + if number not in markers[end]: + markers[end].append(number) + + return { + offset: "".join(f"`[{number}]`" for number in sorted(numbers)) + for offset, numbers in markers.items() + } + + def _annotation_end(self, annotation: dict[str, Any]) -> Optional[int]: + for key in ("end", "end_index", "end_offset", "stop", "stop_index"): + value = annotation.get(key) + if isinstance(value, int): + return value + + span = annotation.get("span") or annotation.get("text_span") or annotation.get("token_span") + if isinstance(span, dict): + for key in ("end", "end_index", "end_offset", "stop", "stop_index"): + value = span.get(key) + if isinstance(value, int): + return value + + return None + + def _annotation_source_numbers(self, annotation: dict[str, Any]) -> list[int]: + values = [] + for key in ( + "source_index", "source_indices", "source_ids", "source_id", + "citation_index", "citation_indices", "citation_ids", "citation_id", + "web_result_index", "web_result_indices", + ): + if key in annotation: + values.extend(self._flatten(annotation[key])) + + for key in ("url", "source_url", "citation_url"): + value = annotation.get(key) + if isinstance(value, str): + number = self._source_number_for_url(value) + if number: + values.append(number) + + numbers: list[int] = [] + for value in values: + number = self._normalize_source_number(value) + if number and number not in numbers: + numbers.append(number) + return numbers + + def _normalize_source_number(self, value: Any) -> Optional[int]: + if isinstance(value, str) and value.isdigit(): + value = int(value) + if not isinstance(value, int): + return None + if 0 <= value < len(self.sources): + return value + 1 + if 1 <= value <= len(self.sources): + return value + return None + + def _source_number_for_url(self, url: str) -> Optional[int]: + for index, source in enumerate(self.sources, start=1): + if source.get("url") == url: + return index + return None + + def _flatten(self, value: Any) -> list[Any]: + if isinstance(value, list): + values: list[Any] = [] + for item in value: + values.extend(self._flatten(item)) + return values + if isinstance(value, dict): + values: list[Any] = [] + for key in ( + "index", "indices", "id", "ids", "source_index", + "source_indices", "citation_index", "citation_indices", + "url", "source_url", "citation_url", + ): + if key in value: + values.extend(self._flatten(value[key])) + return values + return [value] + + def _cited_sources(self) -> list[dict[str, str]]: + numbers = sorted({ + number + for annotation in self._annotations + for number in self._annotation_source_numbers(annotation) + }) + return [ + {"number": str(number), **self.sources[number - 1]} + for number in numbers + if 1 <= number <= len(self.sources) + ] diff --git a/setup.py b/setup.py index 2ce6a41..c2c2e6e 100644 --- a/setup.py +++ b/setup.py @@ -11,5 +11,5 @@ long_description_content_type="text/markdown", url="https://github.com/nathanrchn/perplexityai", packages=find_packages(), - requires=["requests", "websocket-client"] -) \ No newline at end of file + install_requires=["requests", "websocket-client", "streamdown>=0.36.0"] +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..d9fd297 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,118 @@ +from io import StringIO +import unittest +from unittest.mock import patch + +from perplexity.cli import OutputWriter, build_parser, run + + +class FakeRenderer: + instances = [] + + def __init__(self): + self.rendered = [] + self.closed = False + FakeRenderer.instances.append(self) + + def render(self, text): + self.rendered.append(text) + + def tidyup(self): + self.closed = True + + +class FakePerplexity: + instances = [] + + def __init__(self, account): + self.account = account + self.closed = False + FakePerplexity.instances.append(self) + + def search(self, prompt, mode): + self.prompt = prompt + self.mode = mode + return [ + { + "blocks": [ + { + "intended_usage": "ask_text", + "markdown_block": {"answer": "# Hello"}, + } + ] + }, + { + "blocks": [ + { + "intended_usage": "ask_text", + "markdown_block": {"answer": "# Hello\n\nWorld"}, + } + ] + }, + ] + + def close(self): + self.closed = True + + +class OutputWriterTest(unittest.TestCase): + def setUp(self): + FakeRenderer.instances = [] + + def test_raw_writes_plain_text_without_streamdown(self): + stream = StringIO() + writer = OutputWriter(raw=True, stream=stream) + + with patch("perplexity.cli._load_streamdown", side_effect=AssertionError("unexpected streamdown")): + writer.write("# Hello") + writer.write("\n") + writer.close() + + self.assertEqual(stream.getvalue(), "# Hello\n") + self.assertEqual(FakeRenderer.instances, []) + + def test_rendered_output_uses_streamdown_and_tidyup(self): + writer = OutputWriter(raw=False, stream=StringIO()) + + with patch("perplexity.cli._load_streamdown", side_effect=FakeRenderer): + writer.write("# Hello") + writer.write("\n") + writer.close() + + self.assertEqual(FakeRenderer.instances[0].rendered, ["\n", "# Hello\n", "\n"]) + self.assertTrue(FakeRenderer.instances[0].closed) + + +class CliRunTest(unittest.TestCase): + def setUp(self): + FakePerplexity.instances = [] + FakeRenderer.instances = [] + + def test_run_streams_parsed_output_to_streamdown_by_default(self): + args = build_parser().parse_args(["-a", "me@example.com", "hello", "world"]) + + with patch("perplexity.cli.Perplexity", FakePerplexity): + with patch("perplexity.cli._load_streamdown", side_effect=FakeRenderer): + self.assertEqual(run(args), 0) + + self.assertEqual(FakePerplexity.instances[0].account, "me@example.com") + self.assertEqual(FakePerplexity.instances[0].prompt, "hello world") + self.assertEqual(FakePerplexity.instances[0].mode, "concise") + self.assertEqual(FakeRenderer.instances[0].rendered, ["\n", "# Hello\n\n", "World\n", "\n"]) + self.assertTrue(FakeRenderer.instances[0].closed) + self.assertTrue(FakePerplexity.instances[0].closed) + + def test_run_raw_prints_parsed_markdown_without_event_dump(self): + args = build_parser().parse_args(["--raw", "--pro", "hello"]) + stream = StringIO() + + with patch("perplexity.cli.Perplexity", FakePerplexity): + with patch("perplexity.cli.sys.stdout", stream): + self.assertEqual(run(args), 0) + + self.assertEqual(stream.getvalue(), "# Hello\n\nWorld\n") + self.assertEqual(FakePerplexity.instances[0].mode, "copilot") + self.assertEqual(FakeRenderer.instances, []) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..b0a5ad8 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,427 @@ +from datetime import datetime, timezone +from email.message import EmailMessage +from io import StringIO +from pathlib import Path +from tempfile import TemporaryDirectory +import unittest +from unittest.mock import patch + +from perplexity.config import ( + _check_mailbox, + configure_mail, + extract_perplexity_login_url, + is_valid_yes_no, + load_config, + mail_config_for, + parse_yes_no, + save_config, + validate_mail_config, +) + + +class FakeMailbox: + instances = [] + + def __init__(self, hostname, port): + self.hostname = hostname + self.port = port + self.started_tls = False + self.logged_out = False + FakeMailbox.instances.append(self) + + def starttls(self): + self.started_tls = True + + def login(self, username, password): + self.username = username + self.password = password + return ("OK", []) + + def select(self, mailbox): + self.selected_mailbox = mailbox + return ("OK", []) + + def logout(self): + self.logged_out = True + + +class FailingMailbox(FakeMailbox): + def login(self, username, password): + return ("NO", []) + + +class EmptyMailbox(FakeMailbox): + def search(self, charset, criterion): + self.search_charset = charset + self.search_criterion = criterion + return ("OK", [b""]) + + +class LoginMessageMailbox(FakeMailbox): + def search(self, charset, criterion): + return ("OK", [b"123"]) + + def fetch(self, message_id, spec): + self.fetched_message_id = message_id + msg = EmailMessage() + msg["Date"] = datetime.now(timezone.utc).strftime("%a, %d %b %Y %H:%M:%S %z") + msg["From"] = "Perplexity " + msg["To"] = "login@example.com" + msg["Subject"] = "Sign in to Perplexity" + msg.set_content( + "https://www.perplexity.ai/api/auth/callback/email?" + "callbackUrl=defaultMobileSignIn&email=login%40example.com&token=543116" + ) + return ("OK", [(message_id, msg.as_bytes())]) + + def store(self, message_id, command, flags): + self.stored_message_id = message_id + self.store_command = command + self.store_flags = flags + return ("OK", []) + + def expunge(self): + self.expunged = True + return ("OK", []) + + +class MailLoginExtractionTest(unittest.TestCase): + def _message(self, body, date=None): + msg = EmailMessage() + msg["Date"] = (date or datetime(2026, 5, 6, 6, 3, 8, tzinfo=timezone.utc)).strftime( + "%a, %d %b %Y %H:%M:%S %z" + ) + msg["From"] = "Perplexity " + msg["To"] = "d.grieser@mittwald.de" + msg["Subject"] = "Sign in to Perplexity" + msg.set_content(body, cte="quoted-printable") + return msg.as_bytes() + + def test_extracts_wrapped_quoted_printable_url(self): + body = ( + "https://www.perplexity.ai/api/auth/callback/email?=\n" + "callbackUrl=3DdefaultMobileSignIn&email=3Dd.grieser%40mittwald.=\n" + "de&token=3D543116" + ) + url = extract_perplexity_login_url( + self._message(body), + "d.grieser@mittwald.de", + now=datetime(2026, 5, 6, 6, 3, 30, tzinfo=timezone.utc), + ) + self.assertEqual( + url, + "https://www.perplexity.ai/api/auth/callback/email?callbackUrl=defaultMobileSignIn&email=d.grieser%40mittwald.de&token=543116", + ) + + def test_extracts_html_entity_url(self): + body = ( + '' + ) + url = extract_perplexity_login_url( + self._message(body), + "d.grieser@mittwald.de", + now=datetime(2026, 5, 6, 6, 3, 30, tzinfo=timezone.utc), + ) + self.assertEqual( + url, + "https://www.perplexity.ai/api/auth/callback/email?callbackUrl=defaultMobileSignIn&email=d.grieser%40mittwald.de&token=543116", + ) + + def test_rejects_stale_message(self): + raw = self._message( + "https://www.perplexity.ai/api/auth/callback/email?callbackUrl=defaultMobileSignIn&email=d.grieser%40mittwald.de&token=543116", + date=datetime(2026, 5, 6, 6, 1, 0, tzinfo=timezone.utc), + ) + url = extract_perplexity_login_url( + raw, + "d.grieser@mittwald.de", + now=datetime(2026, 5, 6, 6, 3, 0, tzinfo=timezone.utc), + ) + self.assertIsNone(url) + + +class ConfigStorageTest(unittest.TestCase): + def test_saves_and_loads_mail_account(self): + with TemporaryDirectory() as tmp: + with patch("perplexity.config.Path.home", return_value=Path(tmp)): + save_config( + { + "mail": { + "accounts": { + "login@example.com": { + "address": "mailbox@example.com", + "username": "imap-user", + "password": "secret", + "imap_hostname": "imap.example.com", + "port": 993, + "folder": "Archive/Perplexity", + } + } + } + } + ) + + self.assertEqual(load_config()["mail"]["accounts"]["login@example.com"]["port"], 993) + self.assertEqual( + load_config()["mail"]["accounts"]["login@example.com"]["folder"], + "Archive/Perplexity", + ) + self.assertEqual(mail_config_for("login@example.com")["address"], "mailbox@example.com") + self.assertEqual(mail_config_for("LOGIN@example.com")["username"], "imap-user") + + +class MailConfigValidationTest(unittest.TestCase): + def setUp(self): + FakeMailbox.instances = [] + + def test_validates_ssl_mail_config(self): + with patch("perplexity.config.IMAP4_SSL", FakeMailbox): + validate_mail_config( + { + "username": "imap-user", + "password": "secret", + "imap_hostname": "imap.example.com", + "port": 993, + } + ) + + mailbox = FakeMailbox.instances[0] + self.assertEqual(mailbox.hostname, "imap.example.com") + self.assertEqual(mailbox.port, 993) + self.assertEqual(mailbox.username, "imap-user") + self.assertEqual(mailbox.password, "secret") + self.assertEqual(mailbox.selected_mailbox, "INBOX") + self.assertFalse(mailbox.started_tls) + self.assertTrue(mailbox.logged_out) + + def test_validates_starttls_mail_config(self): + with patch("perplexity.config.IMAP4", FakeMailbox): + validate_mail_config( + { + "username": "imap-user", + "password": "secret", + "imap_hostname": "imap.example.com", + "port": 143, + } + ) + + self.assertTrue(FakeMailbox.instances[0].started_tls) + + def test_validates_configured_folder(self): + with patch("perplexity.config.IMAP4_SSL", FakeMailbox): + validate_mail_config( + { + "username": "imap-user", + "password": "secret", + "imap_hostname": "imap.example.com", + "port": 993, + "folder": "Archive/Perplexity", + } + ) + + self.assertEqual(FakeMailbox.instances[0].selected_mailbox, "Archive/Perplexity") + + def test_accepts_yes_no_delete_message_values(self): + for value in ["y", "Y", "yes", "Yes"]: + self.assertTrue(is_valid_yes_no(value)) + self.assertTrue(parse_yes_no(value)) + for value in ["n", "N", "no", "No"]: + self.assertTrue(is_valid_yes_no(value)) + self.assertFalse(parse_yes_no(value)) + + with patch("perplexity.config.stderr", StringIO()): + self.assertFalse(is_valid_yes_no("sure")) + + def test_configure_mail_does_not_save_invalid_mail_config(self): + answers = iter(["login@example.com", "", "imap-user", "imap.example.com", "993", "INBOX", ""]) + + with TemporaryDirectory() as tmp: + with patch("perplexity.config.Path.home", return_value=Path(tmp)): + with patch("perplexity.config.input", side_effect=lambda _prompt: next(answers)): + with patch("perplexity.config.getpass", return_value="secret"): + with patch("perplexity.config.IMAP4_SSL", FailingMailbox): + with patch("perplexity.config.stderr", StringIO()): + with self.assertRaises(RuntimeError): + configure_mail() + + self.assertEqual(load_config(), {}) + + def test_configure_mail_saves_default_folder(self): + answers = iter(["login@example.com", "", "imap-user", "imap.example.com", "993", "", ""]) + + with TemporaryDirectory() as tmp: + with patch("perplexity.config.Path.home", return_value=Path(tmp)): + with patch("perplexity.config.input", side_effect=lambda _prompt: next(answers)): + with patch("perplexity.config.getpass", return_value="secret"): + with patch("perplexity.config.IMAP4_SSL", FakeMailbox): + with patch("perplexity.config.stderr", StringIO()): + configure_mail() + + mail_config = load_config()["mail"]["accounts"]["login@example.com"] + self.assertEqual(mail_config["folder"], "INBOX") + self.assertTrue(mail_config["delete_signin_messages"]) + self.assertEqual(FakeMailbox.instances[0].selected_mailbox, "INBOX") + + def test_configure_mail_saves_no_delete_signin_messages(self): + answers = iter(["login@example.com", "", "imap-user", "imap.example.com", "993", "", "No"]) + + with TemporaryDirectory() as tmp: + with patch("perplexity.config.Path.home", return_value=Path(tmp)): + with patch("perplexity.config.input", side_effect=lambda _prompt: next(answers)): + with patch("perplexity.config.getpass", return_value="secret"): + with patch("perplexity.config.IMAP4_SSL", FakeMailbox): + with patch("perplexity.config.stderr", StringIO()): + configure_mail() + + mail_config = load_config()["mail"]["accounts"]["login@example.com"] + self.assertFalse(mail_config["delete_signin_messages"]) + + def test_configure_mail_uses_existing_account_values_as_defaults(self): + answers = iter(["login@example.com", "", "", "", "", "", ""]) + prompts = [] + password_prompts = [] + stderr = StringIO() + + def input_answer(prompt): + prompts.append(prompt) + return next(answers) + + def password_answer(prompt): + password_prompts.append(prompt) + return "" + + with TemporaryDirectory() as tmp: + with patch("perplexity.config.Path.home", return_value=Path(tmp)): + save_config( + { + "mail": { + "accounts": { + "login@example.com": { + "address": "mailbox@example.com", + "username": "imap-user", + "password": "secret", + "imap_hostname": "imap.example.com", + "port": 143, + "folder": "Archive/Perplexity", + "delete_signin_messages": False, + } + } + } + } + ) + + with patch("perplexity.config.input", side_effect=input_answer): + with patch("perplexity.config.getpass", side_effect=password_answer): + with patch("perplexity.config.IMAP4", FakeMailbox): + with patch("perplexity.config.stderr", stderr): + configure_mail() + + mail_config = load_config()["mail"]["accounts"]["login@example.com"] + self.assertEqual(mail_config["address"], "mailbox@example.com") + self.assertEqual(mail_config["username"], "imap-user") + self.assertEqual(mail_config["password"], "secret") + self.assertEqual(mail_config["imap_hostname"], "imap.example.com") + self.assertEqual(mail_config["port"], 143) + self.assertEqual(mail_config["folder"], "Archive/Perplexity") + self.assertFalse(mail_config["delete_signin_messages"]) + + self.assertIn("Email Address [mailbox@example.com]: ", prompts) + self.assertIn("Username [imap-user]: ", prompts) + self.assertIn("IMAP Server [imap.example.com]: ", prompts) + self.assertIn("Port [143]: ", prompts) + self.assertIn("Folder [Archive/Perplexity]: ", prompts) + self.assertIn("Delete Sign-in Messages [y/N]? ", prompts) + self.assertEqual(password_prompts, ["Password [configured]: "]) + self.assertNotIn("secret", "".join(prompts + password_prompts)) + self.assertIn("Configured accounts:\n - login@example.com\n\n", stderr.getvalue()) + self.assertIn("Validating IMAP login details...\nIMAP login details validated.\n", stderr.getvalue()) + + def test_configure_mail_lists_existing_accounts(self): + answers = iter(["new@example.com", "", "imap-user", "imap.example.com", "993", "", ""]) + stderr = StringIO() + + with TemporaryDirectory() as tmp: + with patch("perplexity.config.Path.home", return_value=Path(tmp)): + save_config( + { + "mail": { + "accounts": { + "z@example.com": {"address": "z@example.com"}, + "a@example.com": {"address": "a@example.com"}, + } + } + } + ) + + with patch("perplexity.config.input", side_effect=lambda _prompt: next(answers)): + with patch("perplexity.config.getpass", return_value="secret"): + with patch("perplexity.config.IMAP4_SSL", FakeMailbox): + with patch("perplexity.config.stderr", stderr): + configure_mail() + + self.assertIn( + "Configure the IMAP mail account which receives Perplexity Sign-in emails...\n\n" + "Configured accounts:\n" + " - a@example.com\n" + " - z@example.com\n\n", + stderr.getvalue(), + ) + + def test_mail_lookup_uses_configured_folder(self): + with patch("perplexity.config.IMAP4_SSL", EmptyMailbox): + self.assertIsNone( + _check_mailbox( + "login@example.com", + { + "username": "imap-user", + "password": "secret", + "imap_hostname": "imap.example.com", + "port": 993, + "folder": "Archive/Perplexity", + }, + ) + ) + + self.assertEqual(FakeMailbox.instances[0].selected_mailbox, "Archive/Perplexity") + + def test_mail_lookup_deletes_used_signin_message_when_configured(self): + with patch("perplexity.config.IMAP4_SSL", LoginMessageMailbox): + url = _check_mailbox( + "login@example.com", + { + "username": "imap-user", + "password": "secret", + "imap_hostname": "imap.example.com", + "port": 993, + "delete_signin_messages": True, + }, + ) + + mailbox = FakeMailbox.instances[0] + self.assertIn("token=543116", url) + self.assertEqual(mailbox.stored_message_id, b"123") + self.assertEqual(mailbox.store_command, "+FLAGS") + self.assertEqual(mailbox.store_flags, "\\Deleted") + self.assertTrue(mailbox.expunged) + + def test_mail_lookup_leaves_used_signin_message_by_default(self): + with patch("perplexity.config.IMAP4_SSL", LoginMessageMailbox): + self.assertIn( + "token=543116", + _check_mailbox( + "login@example.com", + { + "username": "imap-user", + "password": "secret", + "imap_hostname": "imap.example.com", + "port": 993, + }, + ), + ) + + self.assertFalse(hasattr(FakeMailbox.instances[0], "stored_message_id")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_login.py b/tests/test_login.py new file mode 100644 index 0000000..a13ab92 --- /dev/null +++ b/tests/test_login.py @@ -0,0 +1,52 @@ +from io import StringIO +from pathlib import Path +from tempfile import TemporaryDirectory +import unittest +from unittest.mock import patch + +from perplexity.perplexity import Perplexity + + +class FakeCookies: + def get_dict(self): + return {"session": "cookie"} + + +class FakeSession: + def __init__(self): + self.cookies = FakeCookies() + self.get_urls = [] + + def post(self, url, data): + self.post_url = url + self.post_data = data + + def get(self, url): + self.get_urls.append(url) + + +class LoginFallbackTest(unittest.TestCase): + def test_mail_failure_falls_back_to_manual_token(self): + with TemporaryDirectory() as tmp: + perplexity = Perplexity.__new__(Perplexity) + perplexity.session = FakeSession() + + with patch("perplexity.perplexity.config_dir", return_value=Path(tmp)): + with patch("perplexity.perplexity.mail_config_for", return_value={"address": "mailbox@example.com"}): + with patch( + "perplexity.perplexity.retrieve_login_url_from_mail", + side_effect=RuntimeError("imap failed"), + ): + with patch("sys.stdin", StringIO("543116\n")): + with patch("sys.stderr", StringIO()) as stderr: + perplexity._login("login@example.com") + + self.assertIn("Failed to retrieve token from email: imap failed", stderr.getvalue()) + self.assertEqual( + perplexity.session.get_urls[-1], + "https://www.perplexity.ai/api/auth/callback/email?callbackUrl=defaultMobileSignIn&email=login%40example.com&token=543116", + ) + + +if __name__ == "__main__": + unittest.main()