
Improve chat functionality
redadmiral committed Mar 27, 2024
1 parent 6af815c commit dc96741
Showing 3 changed files with 44 additions and 13 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.md
@@ -23,6 +23,11 @@ be accessed anymore. The `expire_on_commit=False` parameter disables this behavior.
 - Fixed a bug where it was not possible to init the model without passing an auth token
 - IGEL now also accepts "IGEL_URL" to query the LLM endpoint.
 
-## 0.1.4.2 -> 0.1.5
+## 0.1.4.2 -> 0.1.4.3
 - Closed issues [2](https://github.com/br-data/rag-tools-library/issues/2) and [10](https://github.com/br-data/rag-tools-library/issues/10)
-by considering the max_new_tokens param in the `fit_to_context_window()` method.
+  by considering the max_new_tokens param in the `fit_to_context_window()` method.
+
+## 0.1.4.3 -> 0.1.5
+- Added tests for FAISS DB
+- Improved handling of the History and Chat functionality
+- TODO: Write documentation for this point.
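The fix referenced above comes down to token budgeting: before a request is sent, the prompt has to be truncated to the context window minus the tokens reserved for the completion, otherwise max_new_tokens pushes the request over the limit. A minimal sketch of that idea, assuming a tiktoken encoding; this is an illustration, not the library's actual `fit_to_context_window()` implementation:

```python
import tiktoken

def fit_to_context_window_sketch(prompt: str, context_window: int, max_new_tokens: int) -> str:
    """Illustrative only: truncate a prompt so prompt + completion fit the window."""
    encoding = tiktoken.get_encoding("cl100k_base")  # assumed encoding, not confirmed by the commit
    budget = max(context_window - max_new_tokens, 0)  # reserve room for the reply
    tokens = encoding.encode(prompt)
    return encoding.decode(tokens[:budget])
```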
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "brdata-rag-tools"
-version = "0.1.4.3"
+version = "0.1.5"
 authors = [
     { name = "Marco Lehner", email = "[email protected]" },
 ]
46 changes: 36 additions & 10 deletions src/brdata_rag_tools/models.py
@@ -1,3 +1,4 @@
+import copy
 import os
 from aenum import Enum, extend_enum
 import tiktoken
@@ -293,22 +294,42 @@ def prompt(self, prompt: str) -> str:
 
         return gen_response.choices[0].message.content
 
-    def chat(self, prompt: str) -> str:
+    def chat(self, prompt: str, stream: bool, n_history: int = 2, history: List[Dict] = None) -> str:
 
         self.history.add(Role.USER, prompt)
 
+        if history is None:
+            injection_history = self.history.get_content(last_n=n_history)
+        else:
+            injection_history = copy.deepcopy(history)
+            injection_history.extend(self.history.get_content(last_n=1))
+            injection_history = injection_history[-n_history:]
+
         gen_response = self.client.chat.completions.create(
             model=self.model.value,
-            messages=self.history.get_content(),
+            messages=injection_history,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
             top_p=self.top_p,
             n=self.number_of_responses,
+            stream=stream
         )
 
-        content = gen_response.choices[0].message.content
-        self.history.add(Role.SYSTEM, content)
-        return content
+        if stream is False:
+            content = gen_response.choices[0].message.content
+            self.history.add(Role.SYSTEM, content)
+            return content
+        else:
+            history = ""
+            for chunk in gen_response:
+                chunk = chunk.choices[0].delta.content
+                if chunk is not None:
+                    history += chunk
+                    yield chunk
+
+            self.history.add(Role.SYSTEM, history)
 
 class IGEL(Generator):
     """
@@ -411,8 +432,8 @@ def prompt(self, prompt: str) -> str:
         """
         return self.model.prompt(prompt)
 
-    def chat(self, prompt: str) -> str:
-        return self.model.chat(prompt)
+    def chat(self, prompt: str, stream: bool = False, n_history: int = 2, history: List[Dict] = None) -> str:
+        return self.model.chat(prompt, stream=stream, n_history=n_history, history=history)
 
     def new_chat(self):
         self.model.history.reset()
@@ -434,8 +455,13 @@ def add(self, role: Role, message: str):
     def reset(self):
         self._history = []
 
-    def get_content(self):
+    def get_content(self, last_n: int = 2):
+        last_n += 1  # the last one is the actual prompt, not history
+        history = self._history[-last_n:]
+
         if self.model_name.family == "GPT":
-            return self._history
+            return history
         else:
-            return [f"{x['role']}: {x['content']}" for x in self._history]
+            return [f"{x['role']}: {x['content']}" for x in history]


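Read together, the three hunks give the chat interface a streaming mode and injectable history. A hedged usage sketch of the resulting interface; `llm` stands in for a configured generator object from `brdata_rag_tools.models`, whose construction is assumed rather than shown in this commit:

```python
# Streaming mode: chat() yields text chunks as they arrive and stores the
# joined reply in the history once the stream is exhausted.
for chunk in llm.chat("Summarize the changelog.", stream=True):
    print(chunk, end="", flush=True)

# Injecting an external history: only the last n_history messages are sent
# along with the new prompt. Role labels here are assumptions.
seed = [
    {"role": "user", "content": "Hi"},
    {"role": "system", "content": "Hello! How can I help?"},
]
answer = "".join(llm.chat("Continue where we left off.", history=seed, n_history=3, stream=True))

# new_chat() resets the stored history for a fresh conversation.
llm.new_chat()
```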