
Commit 5b0be2c

Finishing generation + running_eval
1 parent 95f1945 commit 5b0be2c

19 files changed: +4934 −47 lines

agent_factory.py

+45 −1
@@ -2,6 +2,25 @@
 from llm_engines import LLMApi
 import json
 import random
+import itertools
+
+# set up logging
+import logging
+
+# Create a custom logger
+agent_factory_logger = logging.getLogger('agent_factory')
+agent_factory_logger.setLevel(logging.INFO)
+
+# Create a file handler
+agent_factory_file_handler = logging.FileHandler('agent_factory.log', mode='w')
+agent_factory_file_handler.setLevel(logging.INFO)
+
+# Create a formatter and add it to the handler
+formatter = logging.Formatter('%(asctime)-15s %(message)s')
+agent_factory_file_handler.setFormatter(formatter)
+
+# Add the handler to the logger
+agent_factory_logger.addHandler(agent_factory_file_handler)

 def read_personas():
     list_personas=[]
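Note: this file-logger setup is repeated, with different names, in dialogue_react_agent.py and groupchat_thread.py below. A shared helper could replace the three copies; a minimal sketch (make_file_logger is a hypothetical name, not part of this commit):

import logging

def make_file_logger(name, filename, level=logging.INFO):
    # build a logger that writes to its own file and stays out of the root logger
    logger = logging.getLogger(name)
    logger.setLevel(level)
    handler = logging.FileHandler(filename, mode='w')
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter('%(asctime)-15s %(message)s'))
    logger.addHandler(handler)
    logger.propagate = False
    return logger

# usage: agent_factory_logger = make_file_logger('agent_factory', 'agent_factory.log')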
@@ -36,11 +55,22 @@ def gen_name(persona, neutral_llm=LLMApi()):
     name=""
     while len(name)<2 or len(name)>20:
         name=neutral_llm.generate_response(f"Generate a name for me appropriate for a persona with the following description: {persona}.\nGenerate only a valid first name and nothing else. If your answer contains anything more than a first name, you will be terminated. Follow the name with ## to indicate the end of the name.")
+        # if a preamble like "sure I can" is in the name, start again
+        if "sure i can" in name.lower():
+            agent_factory_logger.info(f"Invalid name: {name}")
+            name=""
+            continue
         if "##" in name:
             name=name.split("##")[0]
             name=name.strip()
-            print(f"Generated name: {name}")
+            if len(name)<2 or len(name)>20:
+                agent_factory_logger.info(f"Invalid name: {name}")
+                name=""
+                continue
             break
+        else:
+            agent_factory_logger.info(f"Invalid name: {name}")
+            name=""
     return name

 def get_persona_by_topics(topics:list, domain=None):
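Note: the retry loop above can be exercised with a stub engine in place of LLMApi; StubLLM and its canned replies are illustrative only, not part of the repo:

class StubLLM:
    def __init__(self, replies):
        self.replies = iter(replies)
    def generate_response(self, prompt):
        return next(self.replies)

# the first reply is rejected for its preamble, the second is accepted and trimmed at "##"
llm = StubLLM(["Sure I can help: Bob##", "Alice## extra text"])
print(gen_name("a friendly gardener", neutral_llm=llm))  # -> Alice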
@@ -140,4 +170,18 @@ def create_groupchat(topics_to_include, n_agents=2, agent_type=DialogueReactAgen

     ## check the topics of the agents
     print(f"Group chat topics: {gc_covered_topics}")
+    # check that agents have different personas
+    for i, j in itertools.combinations(range(len(agents)), 2):
+        while agents[i].persona==agents[j].persona:
+            print("Agents have the same persona")
+            # substitute the second agent of the pair with a freshly generated one
+            agents[j]=get_agent_by_topics(topics, agent_type=agent_type, neutral_llm=neutral_llm, **agent_args)
+        # same for names
+        while agents[i].name==agents[j].name:
+            print("Agents have the same name")
+            agents[j].name=gen_name(agents[j].persona, neutral_llm=neutral_llm)
+            # substitute the name for the $name$ placeholder in the persona description
+            agents[j].persona=agents[j].persona.replace("$name$", agents[j].name)
+

     return agents
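Note: fixing pairs in sequence can reintroduce a duplicate with an earlier agent, so a final set-based check is a cheap safety net; a sketch, assuming agents expose .persona and .name as above:

def all_unique(agents):
    # True if no two agents share a persona or a name
    personas = [a.persona for a in agents]
    names = [a.name for a in agents]
    return len(set(personas)) == len(personas) and len(set(names)) == len(names)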

dialogue_react_agent.py

+36 −18
@@ -47,9 +47,23 @@ def read_dialogue_act_list(file_path:str):
 dialogue_acts_list=read_dialogue_act_list("prompts/dialogue_acts.json")


-logging.basicConfig(level=logging.INFO, filename="chat.log", filemode="w", format="%(asctime)-15s %(message)s")
+# Create a custom logger
+internal_agent_logger = logging.getLogger('agent_internal')
+internal_agent_logger.setLevel(logging.INFO)

+# Create a file handler
+internal_agent_file_handler = logging.FileHandler('agent_internal.log', mode='w')
+internal_agent_file_handler.setLevel(logging.INFO)

+# Create a formatter and add it to the handler
+formatter = logging.Formatter('%(asctime)-15s %(message)s')
+internal_agent_file_handler.setFormatter(formatter)
+
+# Add the handler to the internal_agent_logger
+internal_agent_logger.addHandler(internal_agent_file_handler)
+
+# prevent logging from propagating to the root logger
+internal_agent_logger.propagate = False


 class DialogueReactAgent(ReflectingAgent):
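Note: propagate = False is what keeps this log separate; places_replication.py (below) still configures the root logger on chat.log, and if both modules are imported in the same process, every record here would otherwise be duplicated into that file. A minimal demonstration:

import logging
logging.basicConfig(filename='root.log', filemode='w', level=logging.INFO)

child = logging.getLogger('child')
child.addHandler(logging.FileHandler('child.log', mode='w'))
child.info("once")        # written to child.log and, via propagation, root.log
child.propagate = False
child.info("only child")  # written to child.log only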
@@ -134,7 +148,7 @@ def gen_memories(self, last_messages, n_memories, **kwargs):


         # log compiled memory prompt
-        logging.info(f"Memory prompt: {memory_prompt}")
+        internal_agent_logger.info(f"Memory prompt: {memory_prompt}")

         ## generate until you get the desired number of memories
         memories = []
@@ -202,7 +216,7 @@ def gen_reflections(self, n_memories, n_reflections, **kwargs):
             agent_list=agent_list)

         # log compiled reflection prompt
-        logging.info(f"Reflection prompt: {reflection_prompt}")
+        internal_agent_logger.info(f"Reflection prompt: {reflection_prompt}")

         ## generate until you get the desired number of reflections
         reflections = []
@@ -292,7 +306,7 @@ def get_answer(self, last_messages, extra_context="", **kwargs):
         )

         # log rendered prompt
-        logging.info(f"Answer generation prompt: {prompt}")
+        internal_agent_logger.info(f"Answer generation prompt: {prompt}")

         # get the answer for the dialogue react method
         if self.ablation == None:
@@ -312,20 +326,24 @@ def get_answer(self, last_messages, extra_context="", **kwargs):
                     ## check that thought is a dialogue act
                     observation = observation[0].strip()
                     thought = thought[0].strip()
+                    # remove everything after the first > in the thought
+                    thought = thought.split(">")[0]
+                    # re-add the > at the end
+                    thought = thought + ">"
                     action = action[0].strip()
                     if thought not in self.dialogue_acts:
-                        logging.info(f"Invalid dialogue act: {thought}.")
+                        internal_agent_logger.info(f"Invalid dialogue act: {thought}.")
                         continue
                     answer = action
                     #print(f"Valid answer: {answer_candidate}.")
-                    logging.info(f"Valid dialogue_act: {thought}.")
-                    logging.info(f"Valid answer: {answer}.")
+                    internal_agent_logger.info(f"Valid dialogue_act: {thought}.")
+                    internal_agent_logger.info(f"Valid answer: {answer}.")
                     break
                 else:
-                    logging.info(f"Invalid answer: {answer_candidate}.")
+                    internal_agent_logger.info(f"Invalid answer: {answer_candidate}.")
             except Exception as e:
-                logging.info(f"Error generating response: {e}.")
-                logging.info(f"Invalid answer: {answer_candidate}.")
+                internal_agent_logger.info(f"Error generating response: {e}.")
+                internal_agent_logger.info(f"Invalid answer: {answer_candidate}.")

         answer_with_dialogueAct = f"Following the observation: {observation}, I wanted to commit the following dialogue act: {thought}. Therefore, I wrote the message: {action}."
         ## save the answer in memory
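Note: the split(">")[0] + ">" normalization keeps only the first angle-bracket-terminated tag, which assumes dialogue acts are written like <inform> (implied by the parsing, not shown in the diff):

thought = "<inform> and then some trailing chatter"
thought = thought.split(">")[0] + ">"   # -> "<inform>"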
@@ -350,13 +368,13 @@ def get_answer(self, last_messages, extra_context="", **kwargs):
                     thought = thought[0].strip()
                     action = action[0].strip()
                     answer = action
-                    logging.info(f"Valid answer: {answer}.")
+                    internal_agent_logger.info(f"Valid answer: {answer}.")
                     break
                 else:
-                    logging.info(f"Invalid answer: {answer_candidate}.")
+                    internal_agent_logger.info(f"Invalid answer: {answer_candidate}.")
             except Exception as e:
-                logging.info(f"Error generating response: {e}.")
-                logging.info(f"Invalid answer: {answer_candidate}.")
+                internal_agent_logger.info(f"Error generating response: {e}.")
+                internal_agent_logger.info(f"Invalid answer: {answer_candidate}.")
         # generate answer for the ablation no_dialogue_no_react
         elif self.ablation == "no_dialogue_no_react":
             # in this case, the message is a string that ends with ##
@@ -367,13 +385,13 @@ def get_answer(self, last_messages, extra_context="", **kwargs):
                     answer = answer_candidate
                     # remove the ##
                     answer = answer[:-2]
-                    logging.info(f"Valid answer: {answer}.")
+                    internal_agent_logger.info(f"Valid answer: {answer}.")
                     break
                 else:
-                    logging.info(f"Invalid answer: {answer_candidate}.")
+                    internal_agent_logger.info(f"Invalid answer: {answer_candidate}.")
             except Exception as e:
-                logging.info(f"Error generating response: {e}.")
-                logging.info(f"Invalid answer: {answer_candidate}.")
+                internal_agent_logger.info(f"Error generating response: {e}.")
+                internal_agent_logger.info(f"Invalid answer: {answer_candidate}.")
         ## save the answer in memory
         self.memory.upsert([{"turn_count":turn_count, "text":answer}])

groupchat_thread.py

+23 −5
@@ -4,9 +4,27 @@

 eval_prompt=load_base_prompt("prompts/chat_evaluation.j2")

-logging.basicConfig(level=logging.INFO, filename="chat.log", filemode="w", format="%(asctime)-15s %(message)s")

-class GroupchatThread:
+# Create a custom logger
+group_chat_logger = logging.getLogger('chat_thread')
+group_chat_logger.setLevel(logging.INFO)
+
+# Create a file handler
+group_chat_file_handler = logging.FileHandler('chat.log', mode='w')
+group_chat_file_handler.setLevel(logging.INFO)
+
+# Create a formatter and add it to the handler
+formatter = logging.Formatter('%(asctime)-15s %(message)s')
+group_chat_file_handler.setFormatter(formatter)
+
+# Add the handler to the logger
+group_chat_logger.addHandler(group_chat_file_handler)
+
+# prevent logging from propagating to the root logger
+group_chat_logger.propagate = False
+
+
+class GroupChatThread:
     """
     A group chat thread that simulates a conversation between multiple agents.
     """
@@ -69,7 +87,7 @@ def start_conversation(self):
         assert self.turn == 0, "The conversation has already started."

         # log the start of the conversation and agent list
-        logging.info(f"Starting conversation {self.chat_id} with agents: {[agent.name for agent in self.agent_list]}")
+        group_chat_logger.info(f"Starting conversation {self.chat_id} with agents: {[agent.name for agent in self.agent_list]}")

         random_agent = self.pick_random_agent()
         first_message = (1, random_agent.name, random.choice(self.conversation_starters))
@@ -142,7 +160,7 @@ def evaluate_chat(self, start_index=0, end_index=-1):
                 if eval_score >= 0 and eval_score <= 10:
                     eval_results.append(eval_score)
             except Exception as e:
-                logging.error(f"Error in evaluation response: {eval_result}")
+                group_chat_logger.error(f"Error in evaluation response: {eval_result}")
                 print(e)

         eval_score = sum(eval_results) / 3
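Note: the average divides by a fixed 3, so any evaluation response that fails to parse deflates the score. A length-robust variant (a sketch, assuming an empty result list should yield no score):

eval_score = sum(eval_results) / len(eval_results) if eval_results else None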
@@ -221,7 +239,7 @@ def run_chat(self, max_turns=50):
             raise ValueError("The selection method is not valid.")

         # log the end of the conversation
-        logging.info(f"Ending conversation {self.chat_id} after {self.turn} turns.")
+        group_chat_logger.info(f"Ending conversation {self.chat_id} after {self.turn} turns.")
         # at the end of the chat, evaluate the chat
         if self.n_eval != -1:
             self.evaluate_chat()

llm_engines.py

+17 −1
@@ -18,6 +18,10 @@

 url= "http://127.0.0.1:1200/v1/chat/completions"

+model_info_url= "http://127.0.0.1:1200/v1/internal/model/info"
+
+load_model_url = "http://127.0.0.1:1200/v1/internal/model/load"
+
 headers = {
     "Content-Type": "application/json"
 }
@@ -83,18 +87,30 @@ def __init__(self, history=[], model="turboderp_Mixtral-8x7B-instruct-exl2_5.0bp
         self.url = url
         self.headers = headers
         self.model = model
+
+    def get_current_model(self):
+        # send a request to get the current model
+        response = requests.get(model_info_url, headers=self.headers, verify=False)
+        return response.json()["model_name"]
+

     def generate_response(self, user_prompt):

         assistant_message = ""

+        # make sure the model is loaded
+        if self.get_current_model() != self.model:
+            data = {
+                "model_name": self.model
+            }
+            response = requests.post(load_model_url, headers=self.headers, json=data, verify=False)

         # append user prompt to history
         self.history.append({"role":"user","content": user_prompt})
         # query api for response
         data = {
             "mode": "instruct",
             "messages": self.history,
-            "model": self.model,
         }
         response = requests.post(url, headers=headers, json=data, verify=False)
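Note: the /v1/internal/model/* endpoints match text-generation-webui's internal model-management API (an assumption; the diff does not name the backend). The ensure-loaded pattern in isolation, with ensure_model_loaded as a hypothetical helper and error handling omitted:

import requests

def ensure_model_loaded(wanted, info_url=model_info_url, load_url=load_model_url):
    # swap the loaded model via the internal API if it is not the one we want
    current = requests.get(info_url, headers=headers, verify=False).json()["model_name"]
    if current != wanted:
        requests.post(load_url, headers=headers, json={"model_name": wanted}, verify=False)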

File renamed without changes.

places_replication.py

+43 −11
@@ -4,7 +4,7 @@
 import random, logging, time, json, os, textwrap


-logging.basicConfig(level=logging.INFO, filename="chat.log", filemode="w", format="%(asctime)-15s %(message)s")
+logging.basicConfig(level=logging.INFO, filename="chat.log", filemode="w", format="%(asctime)-15s %(message)s", force=True)


 # few-shot examples for the PLACES paper are contained in a jsonl file at prompts/places_examples.jsonl; let's load them, each line is a few-shot example conversation
@@ -97,7 +97,33 @@ def dump_chat(self):
             os.makedirs("chat_logs")
         with open(f"chat_logs/{self.chat_id}.json", "w") as f:
             json.dump(chat_data, f)
+
+    def summarize_personas_into_oneliner(self):
+        """
+        Summarizes the personas of the agents into a one-liner.
+        """
+        # get the unique personas
+        personas=set([agent.persona for agent in self.agent_list])
+
+        # make it a string
+        personas_string = ". ".join(personas)
+
+        examples="The following is a conversation between Alice and Bob about grocery shopping. Alice has a shopping list for Bob.\nThe following is a conversation between Alice and Bob about relationships. Bob recently got engaged.\nThe following is a conversation between Alice and Bob about their hobbies. Alice enjoys tennis and Bob likes playing soccer."
+
+        prompt = f"Generate a sentence like the following one-line summary of a conversation premise, but using the personas of the agents in the conversation. Remember to always use 'The following is a conversation between' and use the exact names from the personas. Include some details from topics the speakers are interested in based on their personas.\nPersonas:\n\n{personas_string}\n\nFew-shot examples:\n{examples}. Answer (one line ending with ##):"

+        answer = ""
+
+        # keep asking until the answer is valid
+        while len(answer)<10 or len(answer)>100:
+            answer = self.neutral_llm.generate_response(prompt)
+            if "##" in answer and "The following is a conversation between" in answer and self.agent_list[0].name in answer and self.agent_list[1].name in answer:
+                answer=answer.split("##")[0]
+                answer=answer.strip()
+                return answer
+            else:
+                logging.info(f"Invalid answer: {answer}")
+                answer = ""

     def generate_conversation(self, min_turns :int=10, start_conversation:bool=True):
         """
@@ -124,20 +150,12 @@ def generate_conversation(self, min_turns :int=10, start_conversation:bool=True)
             conversation_starter_string = f"{speaker.name}: {conversation_starter}"
         else:
             conversation_starter_string = ""

-        # fill out the prompt template
-        prompt = self.generation_template.render(
-            agent_names=[agent.name for agent in self.agent_list],
-            agent_personas=[agent.persona for agent in self.agent_list],
-            conversation_starter_string=conversation_starter_string,
-            first_speaker=speaker.name,
-            few_shot_examples=random.sample(self.few_shot_examples,3)
-        )
-
-        logging.info(f"Naive generation prompt: {prompt}")

         # the answer should be a list of strings, one message per line, until max_turns is reached
@@ -146,21 +164,35 @@ def generate_conversation(self, min_turns :int=10, start_conversation:bool=True)
         # keep generating until you have a conversation of at least min_turns length

         while self.turn < min_turns:
+            # fill out the prompt template
+            prompt = self.generation_template.render(
+                agent_names=[agent.name for agent in self.agent_list],
+                conversation_premise=self.summarize_personas_into_oneliner(),
+                conversation_starter_string=conversation_starter_string,
+                first_speaker=speaker.name,
+                few_shot_examples=random.sample(self.few_shot_examples,3)
+            )
+            logging.info(f"Naive generation prompt: {prompt}")
+
             self.turn = 1 if start_conversation else 0

             # reset the chat history for each attempt
             temp_chat_history = []

             # generate the answer
             answer = self.neutral_llm.generate_response(prompt)
+            logging.info(f"Naive generation answer: {answer}")

+            # remove double newlines just in case
+            answer = answer.replace("\n\n", "\n")
             # split the answer into lines; each line should start with the agent name and a colon
             answer_lines = answer.split("\n")

             # add the answer to the chat history
             for line in answer_lines:
                 # split the line into agent name and message only at the first colon
                 try:
+                    line=line.strip()
                     agent_name, message = line.split(":", 1)
                     if agent_name in [agent.name for agent in self.agent_list]:
                         temp_chat_history.append((self.turn, agent_name, message))
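Note: splitting at the first colon keeps any later colons inside the message, and a line without a colon raises ValueError and is skipped by the surrounding try block:

line = "Alice: remember: milk, eggs"
agent_name, message = line.split(":", 1)
# agent_name == "Alice", message == " remember: milk, eggs"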
