Skip to content

Commit

Permalink
♻️ use dataframe framework
Browse files Browse the repository at this point in the history
  • Loading branch information
FR-SON committed Jan 23, 2024
1 parent 92f40b9 commit 1a129a4
Show file tree
Hide file tree
Showing 10 changed files with 342 additions and 236 deletions.
2 changes: 1 addition & 1 deletion tracex/extraction/logic/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
MAX_TOKENS = 1100
TEMPERATURE_SUMMARIZING = 0
TEMPERATURE_CREATION = 1
CSV_OUTPUT = settings.BASE_DIR / "extraction/content/outputs/intermediates/7_output.csv"
CSV_OUTPUT = settings.BASE_DIR / "extraction/content/outputs/single_trace.csv"
CSV_ALL_TRACES = settings.BASE_DIR / "extraction/content/outputs/all_traces.csv"
104 changes: 104 additions & 0 deletions tracex/extraction/logic/function_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# pylint: disable=line-too-long
"""Tool schemas for OpenAI function calling.

Each tool declares exactly one required argument, ``output``, which is an
array of strings. The tools differ only in their name, their description and
the description of what each array item should contain.
"""


def _array_output_tool(name, description, item_description):
    """Build one function-tool schema whose sole parameter is ``output``, an array of strings."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    "output": {
                        "type": "array",
                        "items": {
                            "type": "string",
                            "description": item_description,
                        },
                    },
                },
                "required": ["output"],
            },
        },
    }


# One row per tool: (function name, function description, array item description).
_TOOL_SPECS = (
    (
        "add_start_dates",
        "this function extracts the start date",
        "a date in the format YYYYMMDDT0000 and if not available N/A",
    ),
    (
        "add_end_dates",
        "this function extracts the end date",
        "a end date in the format YYYYMMDDT0000",
    ),
    (
        "add_duration",
        "this function extracts the duration",
        "a duration in the format HHH:MM:SS or HH:MM:SS",
    ),
    (
        "add_event_type",
        "this function extracts the event type",
        "an event type (one of 'Symptom Onset', 'Symptom Offset', 'Diagnosis', 'Doctor visit', 'Treatment', 'Hospital stay', 'Medication', 'Lifestyle Change' and 'Feelings')",
    ),
    (
        "add_location",
        "this function extracts the location",
        "a location (one of 'Home', 'Hospital', 'Doctors' and 'Other')",
    ),
)

# Order matches the spec table above; entries are selected by function name via ``tool_choice``.
TOOLS = [_array_output_tool(*spec) for spec in _TOOL_SPECS]
31 changes: 8 additions & 23 deletions tracex/extraction/logic/modules/module_activity_labeler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pandas as pd
from ..module import Module
from .. import utils as u
from .. import prompts as p

from pandas import DataFrame


class ActivityLabeler(Module):
"""
Expand All @@ -20,34 +19,20 @@ def execute(self, _input, patient_journey=None):
super().execute(_input, patient_journey)
self.result = self.__extract_activities()

# TODO: Convert to dataframes
def __extract_activities(self):
"""Converts the input text to activity_labels."""
name = "event_information"
messages = [
{"role": "system", "content": p.TXT_TO_BULLETPOINTS_CONTEXT},
{
"role": "user",
"content": p.TXT_TO_BULLETPOINTS_PROMPT + self.patient_journey,
"content": f"{p.TXT_TO_BULLETPOINTS_PROMPT} {self.patient_journey}",
},
{"role": "assistant", "content": p.TXT_TO_BULLETPOINTS_ANSWER},
]
activity_labels = u.query_gpt(messages)
activity_labels = self.__remove_commas(activity_labels)
activity_labels = self.__add_ending_commas(activity_labels)
with open((u.output_path / "intermediates/1_bulletpoints.txt"), "w") as f:
f.write(activity_labels)
return activity_labels

@staticmethod
def __remove_commas(activity_labels):
"""Removes commas from within the activity_labels."""
activity_labels = activity_labels.replace(", ", "/")
activity_labels = activity_labels.replace(",", "/")
return activity_labels

@staticmethod
def __add_ending_commas(activity_labels):
"""Adds commas at the end of each line."""
activity_labels = activity_labels.replace("\n", ",\n")
activity_labels = activity_labels + ","
return activity_labels
# TODO: adjust prompt to remove "-" instead of replace()
activity_labels = activity_labels.replace("- ", "").split("\n")
df = pd.DataFrame(activity_labels, columns=[name])
# document_intermediates("\n", True)
return df
60 changes: 35 additions & 25 deletions tracex/extraction/logic/modules/module_event_type_classifier.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pandas as pd

from ..module import Module
from .. import utils as u
from .. import prompts as p


from pandas import DataFrame


Expand All @@ -18,35 +21,42 @@ def __init__(self):
self.name = "Event Type Classifier"
self.description = "Classifies the event types for the corresponding activity labels from a patient journey."

def execute(self, _input, patient_journey=None):
super().execute(_input, patient_journey)
self.result = self.__add_event_types(_input)
def execute(self, df, patient_journey=None):
super().execute(df, patient_journey)
self.result = self.__add_event_types(df)

def __add_event_types(self, df):
"""Adds event types to the activity labels."""
name = "event_type"
df[name] = df["event_information"].apply(self.__classify_event_type)
# document_intermediates(output)

def __add_event_types(self, activity_labels):
"""Adds event types to the bulletpoints."""
return df

@staticmethod
def __classify_event_type(activity_label):
messages = [
{"role": "system", "content": p.BULLETPOINTS_EVENT_TYPE_CONTEXT},
{"role": "system", "content": p.EVENT_TYPE_CONTEXT},
{
"role": "user",
"content": f"{p.EVENT_TYPE_PROMPT}\n The bulletpoint: {activity_label}",
},
{"role": "assistant", "content": p.EVENT_TYPE_ANSWER},
]
output = u.query_gpt(messages)
fc_message = [
{"role": "system", "content": p.FC_EVENT_TYPE_CONTEXT},
{
"role": "user",
"content": p.BULLETPOINTS_EVENT_TYPE_PROMPT + activity_labels,
"content": f"{p.FC_EVENT_TYPE_PROMPT} The text: {output}",
},
{"role": "assistant", "content": p.BULLETPOINTS_EVENT_TYPE_ANSWER},
]
activity_labels_with_event_types = u.query_gpt(messages)
activity_labels_with_event_types = self.__add_ending_commas(
activity_labels_with_event_types
event_type = u.query_gpt(
messages=fc_message,
tool_choice={
"type": "function",
"function": {"name": "add_event_type"},
},
)
with open(
(u.output_path / "intermediates/5_bulletpoints_with_event_type.txt"),
"w",
) as f:
f.write(activity_labels_with_event_types)
return activity_labels_with_event_types

# TODO: Remove when dataframes are used
@staticmethod
def __add_ending_commas(activity_labels):
"""Adds commas at the end of each line."""
activity_labels = activity_labels.replace("\n", ",\n")
activity_labels = activity_labels + ","
return activity_labels

return event_type
50 changes: 28 additions & 22 deletions tracex/extraction/logic/modules/module_location_extractor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pandas as pd

from ..module import Module
from .. import utils as u
from .. import prompts as p
Expand All @@ -16,33 +18,37 @@ def __init__(self):
self.name = "Location Extractor"
self.description = "Extracts the locations for the corresponding activity labels from a patient journey."

def execute(self, _input, patient_journey=None):
super().execute(_input, patient_journey)
self.result = self.__add_locations(_input)
def execute(self, df, patient_journey=None):
super().execute(df, patient_journey)
self.result = self.__add_locations(df)

def __add_locations(self, activity_labels):
def __add_locations(self, df):
"""Adds locations to the activity labels."""
name = "attribute_location"
df[name] = df["event_information"].apply(self.__classify_location)
# document_intermediates(output)

return df

@staticmethod
def __classify_location(activity_label):
messages = [
{"role": "system", "content": p.BULLETPOINTS_LOCATION_CONTEXT},
{"role": "system", "content": p.LOCATION_CONTEXT},
{"role": "user", "content": f"{p.LOCATION_PROMPT} {activity_label}"},
{"role": "assistant", "content": p.LOCATION_ANSWER},
]
output = u.query_gpt(messages)

fc_message = [
{"role": "system", "content": p.FC_LOCATION_CONTEXT},
{
"role": "user",
"content": p.BULLETPOINTS_LOCATION_PROMPT + activity_labels,
"content": f"{p.FC_LOCATION_PROMPT} The text: {output}",
},
{"role": "assistant", "content": p.BULLETPOINTS_LOCATION_ANSWER},
]
activity_labels_location = u.query_gpt(messages)
activity_labels_location = self.__remove_brackets(activity_labels_location)
with open(
(u.output_path / "intermediates/6_bulletpoints_with_location.txt"),
"w",
) as f:
f.write(activity_labels_location)
return activity_labels_location
location = u.query_gpt(
messages=fc_message,
tool_choice={"type": "function", "function": {"name": "add_location"}},
)

@staticmethod
def __remove_brackets(activity_labels):
"""Removes brackets from within the activity_labels."""
characters_to_remove = "()[]{}"
for char in characters_to_remove:
activity_labels = activity_labels.replace(char, "")
return activity_labels
return location
Loading

0 comments on commit 1a129a4

Please sign in to comment.