Skip to content

Commit

Permalink
implement requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
nils-schmitt committed Jan 19, 2024
1 parent 7bce50b commit 33a6500
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 110 deletions.
8 changes: 8 additions & 0 deletions tracex/extraction/content/outputs/all_traces.csv
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,11 @@ caseID,event_information,start_date,end_date,duration,event_type,attribute_locat
2,undergoing PCR test,20220617T0000,20220617T0000,00:00:00,Diagnosis,Home
2,receiving negative test results,20220617T0000,20220619T0000,48:00:00,Diagnosis,Doctors
2,getting infected and testing positive,20220617T0000,20220617T0000,00:00:00,Symptom Onset,Home
3,experiencing first Covid-19 symptoms,20220715T0000,20220729T0000,336:00:00,Symptom Onset,Home
3,isolating myself to prevent spread,20220715T0000,20220729T0000,336:00:00,Lifestyle Change,Home
3,"symptoms worsening: persistent cough, fever, fatigue",20220715T0000,20220718T0000,72:00:00,Seeking Medical Advice,Home
3,contacting primary care physician for advice,20220718T0000,20220718T0000,00:00:00,Doctor visit,Doctors
3,monitoring symptoms closely,20220715T0000,20220725T0000,240:00:00,Symptom Onset,Home
3,getting tested for Covid-19,20220718T0000,20220719T0000,24:00:00,Diagnosis,Hospital
3,continuing to isolate at home,20220715T0000,20220729T0000,336:00:00,Other,Home
3,taking over-the-counter medications for symptoms,20220715T0000,20220722T0000,168:00:00,Treatment,Home
10 changes: 5 additions & 5 deletions tracex/extraction/prototype/function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"type": "function",
"function": {
"name": "add_start_dates",
"description": "this function extract the start date",
"description": "this function extracts the start date",
"parameters": {
"type": "object",
"properties": {
Expand All @@ -25,7 +25,7 @@
"type": "function",
"function": {
"name": "add_end_dates",
"description": "this function extract the end date",
"description": "this function extracts the end date",
"parameters": {
"type": "object",
"properties": {
Expand All @@ -45,7 +45,7 @@
"type": "function",
"function": {
"name": "add_duration",
"description": "this function extract the duration",
"description": "this function extracts the duration",
"parameters": {
"type": "object",
"properties": {
Expand All @@ -65,7 +65,7 @@
"type": "function",
"function": {
"name": "add_event_type",
"description": "this function extract the event type",
"description": "this function extracts the event type",
"parameters": {
"type": "object",
"properties": {
Expand All @@ -85,7 +85,7 @@
"type": "function",
"function": {
"name": "add_location",
"description": "this function extract the location",
"description": "this function extracts the location",
"parameters": {
"type": "object",
"properties": {
Expand Down
126 changes: 31 additions & 95 deletions tracex/extraction/prototype/input_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,10 @@ def convert_text_to_bulletpoints(text):
df = pd.DataFrame([], columns=["event_information"])
bulletpoints = bulletpoints.replace("- ", "")
bulletpoints = bulletpoints.split("\n")
for i in bulletpoints:
new_row = pd.DataFrame([i], columns=["event_information"])
for row in bulletpoints:
new_row = pd.DataFrame([row], columns=["event_information"])
df = pd.concat([df, new_row], ignore_index=True)
with open(
(u.output_path / "intermediates/bulletpoints.txt"),
"w",
) as f:
f.write("\n")
document_intermediates("\n", True)
return df


Expand All @@ -75,7 +71,6 @@ def add_start_dates(text, df):
name = "start_date"
new_df = pd.DataFrame([], columns=[name])
values_list = df.values.tolist()
i = 0
for item in values_list:
messages = [
{"role": "system", "content": p.START_DATE_CONTEXT},
Expand All @@ -89,9 +84,7 @@ def add_start_dates(text, df):
},
{"role": "assistant", "content": p.START_DATE_ANSWER},
]

output = u.query_gpt(messages)

fc_message = [
{"role": "system", "content": p.FC_START_DATE_CONTEXT},
{"role": "user", "content": p.FC_START_DATE_PROMPT + "The text: " + output},
Expand All @@ -103,21 +96,13 @@ def add_start_dates(text, df):
new_row = pd.DataFrame([start_date], columns=[name])
new_df = pd.concat([new_df, new_row], ignore_index=True)
row_count = new_df.shape[0]

if start_date == "N/A" and row_count > 1:
last_index = new_df.index[-1]
previous_index = last_index - 1
new_df.at[last_index, "start_date"] = new_df.at[
previous_index, "start_date"
]

print(name + ": " + str(i) + " ", end="\r")
i = i + 1
with open(
(u.output_path / "intermediates/bulletpoints.txt"),
"a",
) as f:
f.write("\n" + output)
document_intermediates(output)
df = pd.concat([df, new_df], axis=1)
return df

Expand All @@ -127,7 +112,6 @@ def add_end_dates(text, df):
name = "end_date"
new_df = pd.DataFrame([], columns=[name])
values_list = df.values.tolist()
i = 0
for item in values_list:
messages = [
{"role": "system", "content": p.END_DATE_CONTEXT},
Expand All @@ -144,7 +128,6 @@ def add_end_dates(text, df):
{"role": "assistant", "content": p.END_DATE_ANSWER},
]
output = u.query_gpt(messages)

fc_message = [
{"role": "system", "content": p.FC_END_DATE_CONTEXT},
{"role": "user", "content": p.FC_END_DATE_PROMPT + "The text: " + output},
Expand All @@ -155,80 +138,26 @@ def add_end_dates(text, df):
)
new_row = pd.DataFrame([end_date], columns=[name])
new_df = pd.concat([new_df, new_row], ignore_index=True)
print(name + ": " + str(i) + " ", end="\r")
i = i + 1
with open(
(u.output_path / "intermediates/bulletpoints.txt"),
"a",
) as f:
f.write("\n" + output)
document_intermediates(output)
df = pd.concat([df, new_df], axis=1)
return df


# def add_durations(text, df):
# """Adds durations to the bulletpoints."""
# name = "duration"
# new_df = pd.DataFrame([], columns=[name])
# values_list = df.values.tolist()
# i = 0
# for item in values_list:
# messages = [
# {"role": "system", "content": p.DURATION_CONTEXT},
# {
# "role": "user",
# "content": p.DURATION_PROMPT
# + "\nThe text: "
# + text
# + "\nThe bulletpoint: "
# + item[0]
# + "\nThe start date: "
# + item[1]
# + "\nThe end date: "
# + item[2],
# },
# {"role": "assistant", "content": p.DURATION_ANSWER},
# ]
# output = u.query_gpt(messages)

# fc_message = [
# {"role": "system", "content": p.FC_DURATION_CONTEXT},
# {"role": "user", "content": p.FC_DURATION_PROMPT + "The text: " + output},
# ]
# duration = u.query_gpt(
# fc_message,
# tool_choice={"type": "function", "function": {"name": "add_duration"}},
# )
# new_row = pd.DataFrame([duration], columns=[name])
# new_df = pd.concat([new_df, new_row], ignore_index=True)
# print(name + ": " + str(i) + " ", end="\r")
# i = i + 1
# with open(
# (u.output_path / "intermediates/bulletpoints.txt"),
# "a",
# ) as f:
# f.write("\n" + output)
# df = pd.concat([df, new_df], axis=1)
# return df


def add_durations(df):
"""Funktion zur Berechnung der Dauer im gewünschten Format"""
"""Calculates and adds the duration for every event information."""

def calculate_row_duration(row):
if row["start_date"] == "N/A" or row["end_date"] == "N/A":
return "N/A"

start_date = datetime.strptime(row["start_date"], "%Y%m%dT%H%M")
end_date = datetime.strptime(row["end_date"], "%Y%m%dT%H%M")
duration = end_date - start_date
hours, remainder = divmod(duration.total_seconds(), 3600)
minutes, seconds = divmod(remainder, 60)
return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"

# Neue Spalte 'duration' erstellen und für jede Zeile die Dauer berechnen
# create new column 'duration' and calculate the duration for every row
df["duration"] = df.apply(calculate_row_duration, axis=1)

return df


Expand All @@ -237,7 +166,6 @@ def add_event_types(df):
name = "event_type"
new_df = pd.DataFrame([], columns=[name])
values_list = df.values.tolist()
i = 0
for item in values_list:
messages = [
{"role": "system", "content": p.EVENT_TYPE_CONTEXT},
Expand All @@ -259,13 +187,7 @@ def add_event_types(df):
)
new_row = pd.DataFrame([event_type], columns=[name])
new_df = pd.concat([new_df, new_row], ignore_index=True)
print(name + ": " + str(i) + " ", end="\r")
i = i + 1
with open(
(u.output_path / "intermediates/bulletpoints.txt"),
"a",
) as f:
f.write("\n" + output)
document_intermediates(output)
df = pd.concat([df, new_df], axis=1)
return df

Expand All @@ -276,7 +198,6 @@ def add_locations(df):
new_df = pd.DataFrame([], columns=[name])
values_list = df.values.tolist()
event_type_key = df.columns.get_loc("event_type")
i = 0
for item in values_list:
print(item[0], end="\r")
messages = [
Expand All @@ -302,24 +223,19 @@ def add_locations(df):
)
new_row = pd.DataFrame([location], columns=[name])
new_df = pd.concat([new_df, new_row], ignore_index=True)
print(name + ": " + str(i) + " ", end="\r")
i = i + 1
with open(
(u.output_path / "intermediates/bulletpoints.txt"),
"a",
) as f:
f.write("\n\n" + output)
document_intermediates(output)
df = pd.concat([df, new_df], axis=1)
return df


def convert_dataframe_to_csv(df):
"""Converts the dataframe to CSV and save it on disk."""
"""Converts the dataframe to CSV and saves it on disk."""
output_path = u.output_path / "single_trace.csv"
df.insert(loc=0, column="case_id", value="0")
df.to_csv(
path_or_buf=output_path, sep=",", encoding="utf-8", header=True, index=False
)
document_intermediates(df, is_dataframe=True)
return output_path


Expand Down Expand Up @@ -361,3 +277,23 @@ def append_csv():
def farewell():
"""Prints a farewell message."""
print("-----------------------------------\nThank you for using TracEX!\n\n")


def document_intermediates(text, is_first=False, is_dataframe=False):
"""Writes the text to a file."""
if is_dataframe:
with open((u.output_path / "intermediates/bulletpoints.txt"), "a") as f:
text = text.to_string(header=False, index=False)
f.write("\n\n\n\nThe resulting Dataframe:\n\n" + text)
if is_first:
with open(
(u.output_path / "intermediates/bulletpoints.txt"),
"w",
) as f:
f.write(text)
else:
with open(
(u.output_path / "intermediates/bulletpoints.txt"),
"a",
) as f:
f.write("\n" + text)
10 changes: 0 additions & 10 deletions tracex/extraction/prototype/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,6 @@ def get_decision(question):
return get_decision(question)


# def query_gpt(messages, temperature=TEMPERATURE_SUMMARIZING):
# """Queries the GPT engine."""
# response = client.chat.completions.create(model=MODEL,
# messages=messages,
# max_tokens=MAX_TOKENS,
# temperature=temperature)
# output = response.choices[0].message.content
# return output


def query_gpt(
messages, tools=fc.TOOLS, tool_choice="none", temperature=TEMPERATURE_SUMMARIZING
):
Expand Down

0 comments on commit 33a6500

Please sign in to comment.