-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathplotgpt.py
More file actions
120 lines (96 loc) · 4.43 KB
/
plotgpt.py
File metadata and controls
120 lines (96 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import pandas as pd
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage
import re
from typing import Tuple
import tiktoken
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    """Return the number of tokens a list of chat messages will consume.

    Adapted from the OpenAI cookbook:
    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    # Unpinned model names drift over time; recount against a pinned snapshot.
    pinned_aliases = {
        "gpt-3.5-turbo": "gpt-3.5-turbo-0301",
        "gpt-4": "gpt-4-0314",
    }
    if model in pinned_aliases:
        pinned = pinned_aliases[model]
        print(
            f"Warning: {model} may change over time. Returning num tokens assuming {pinned}."
        )
        return num_tokens_from_messages(messages, model=pinned)

    # Per-message framing overhead: <|start|>{role/name}\n{content}<|end|>\n
    overhead_by_model = {"gpt-3.5-turbo-0301": 4, "gpt-4-0314": 3}
    if model not in overhead_by_model:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )
    tokens_per_message = overhead_by_model[model]

    num_tokens = sum(
        tokens_per_message + len(encoding.encode(msg.content)) for msg in messages
    )
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
def _parse_code(content: str) -> str:
code_regex = r"```(?:python)?\n([\s\S]*)```"
match = re.search(code_regex, content)
code = match.group(1)
return code
def _pandas_dtype_str(df: pd.DataFrame) -> str:
return " ".join([f"{col}({dtype})" for col, dtype in df.dtypes.items()])
# System prompt template; {dtype_str} is filled with the dataframe's
# column/dtype summary produced by _pandas_dtype_str().
_SYSTEM_TEMPLATE = "You are an expert Python data analyst. You have been given a dataframe with the following columns: `{dtype_str}`"
# Token budget for one request (system prompt + replayed history + new message).
_MAX_TOKENS = 4096
# At most this many past messages are replayed as conversation context.
_MAX_HISTORY_CONTEXT = 50
# Chat model name passed to ChatOpenAI.
_MODEL = "gpt-3.5-turbo"
class PlotGPT:
    """Chat-driven plotting assistant.

    Inspects a dataframe, asks an OpenAI chat model for matplotlib/seaborn
    code in natural language, and executes the generated code against the
    dataframe.
    """

    def __init__(self, show_code: bool = True) -> None:
        """
        Args:
            show_code: When True, print the generated code before running it.
        """
        self._llm = ChatOpenAI(model_name=_MODEL)
        self._history = []  # alternating Human/AI messages from past ask() calls
        self._system_prompt = None  # set by inspect(); required before ask()
        self._show_code = show_code

    def _clear_history(self) -> None:
        """Reset conversation state (used when a new dataframe is inspected)."""
        self._history = []
        self._system_prompt = None

    def inspect(self, df: pd.DataFrame) -> None:
        """Point the assistant at *df*.

        Resets the conversation and builds a system prompt describing the
        dataframe's columns and dtypes.
        """
        self._clear_history()
        self._system_prompt = SystemMessage(
            content=_SYSTEM_TEMPLATE.format(dtype_str=_pandas_dtype_str(df))
        )
        self._df = df

    def _construct_messages(self, new_msg: HumanMessage):
        """Assemble ``[system, *recent history, new_msg]`` within token budget.

        Oldest (human, ai) history pairs are dropped until the prompt fits.

        Raises:
            ValueError: If the system prompt plus *new_msg* alone exceed
                the ``_MAX_TOKENS`` limit.
        """
        num_remaining_tokens = _MAX_TOKENS - num_tokens_from_messages(
            [self._system_prompt, new_msg], model=_MODEL
        )
        if num_remaining_tokens < 0:
            raise ValueError(f"prompt is too long for {_MAX_TOKENS} limit!")
        history = self._history[-1 * _MAX_HISTORY_CONTEXT:]
        # BUG FIX: guard on `history` — the original `while` looped forever
        # when history was already empty but the constant empty-message
        # overhead (3 tokens) still exceeded the remaining budget, because
        # [][2:] is still [].
        while history and num_tokens_from_messages(history) > num_remaining_tokens:
            history = history[2:]  # drop the oldest (human, ai) pair
        return [self._system_prompt] + history + [new_msg]

    def _get_response(self, prompt) -> Tuple[AIMessage, HumanMessage]:
        """Send *prompt* to the LLM; return (ai_response, human message sent)."""
        assert self._system_prompt is not None, "Inspect a dataframe first!"
        new_msg = HumanMessage(
            content=f"Give me the only python code and nothing else for the following. Only use matplotlib and seaborn. Assume I have the dataframe preloaded as `df`: {prompt}"
        )
        messages = self._construct_messages(new_msg)
        resp = self._llm(messages)
        return resp, new_msg

    def ask(self, prompt) -> None:
        """Ask for a plot in natural language and execute the generated code.

        The LLM's code runs with the inspected dataframe bound as ``df``.
        """
        assert self._system_prompt is not None, "Need to inspect a dataframe first!"
        ai_response, msg = self._get_response(prompt)
        self._history += [msg, ai_response]
        code = _parse_code(ai_response.content)
        if self._show_code:
            print(code)
        # SECURITY: exec() runs arbitrary LLM-generated code in-process.
        # Acceptable for a local, interactive analysis tool; never expose
        # this path to untrusted input.
        try:
            exec(code, {"df": self._df})
        except Exception as e:  # implements the original "TODO: try/except plot"
            # Surface the failure (and the offending code, if it wasn't
            # already printed) instead of crashing the session.
            if not self._show_code:
                print(code)
            print(f"Generated code failed to execute: {e!r}")
# ai.ask("plot sepal width vs sepal length")
# ai.ask("now color it by species")
# ai.ask("make separate sepal width vs sepal length scatterplot subplots per species. Combine it into a single figure")