-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added log probs metrics and removed cot+fc from event type and location prompts
- Loading branch information
Showing
7 changed files
with
180 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# pylint: skip-file | ||
# pylint: enable=wrong-import-position | ||
import os | ||
|
||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tracex.tracex.settings") | ||
|
||
from tracex.extraction.prototype import input_inquiry as ii | ||
from tracex.extraction.prototype import input_handling as ih | ||
from tracex.extraction.prototype import utils as u | ||
from tracex.extraction.prototype import function_calls as fc | ||
from tracex.extraction.prototype import metrics as m | ||
from tracex.extraction.prototype import create_xes as x | ||
|
||
# Read the sample patient journey; the context manager closes the file handle
# as soon as the text is loaded instead of leaving it to the garbage collector.
with open(u.input_path / "journey_synth_covid_0.txt") as journey_file:
    text = journey_file.read()
# df = ih.convert_text_to_bulletpoints(text)
# print(df)

# Run both log-probability measurements on the same input text.
df = m.measure_event_types(text)
print(df)
# NOTE: df is rebound here, so only the location measurement frame is passed
# to convert_dataframe_to_csv below.
df = m.measure_location(text)
print(df)
ih.convert_dataframe_to_csv(df)

# Regular extraction pipeline, currently disabled while the metrics are measured.
# df = ih.add_start_dates(text, df)
# print(df)
# df = ih.add_end_dates(text, df)
# print(df)
# df = ih.add_durations(df)
# print(df)
# df = ih.add_event_types(df)
# print(df)
# df = ih.add_locations(df)
# print(df)
# ih.convert_dataframe_to_csv(df)
# x.create_xes(u.output_path / "single_trace.csv", "test", "event_information")
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
case_id,event_information,start_date,end_date,duration,event_type,attribute_location | ||
0,"experiencing first Covid-19 symptoms: mild cough, fatigue",20220601T0000,20220608T0000,168:00:00,Symptom Onset,Home | ||
0,brushing off symptoms as common cold,20220601T0000,20220608T0000,168:00:00,Other,Home | ||
0,developing high fever and difficulty breathing,20220601T0000,20220611T0000,240:00:00,Symptom Onset,Hospital | ||
0,deciding to get tested for Covid-19,20220601T0000,20220617T0000,384:00:00,Diagnosis,Doctors | ||
0,going to local testing center,20220617T0000,20220617T0000,00:00:00,Diagnosis,Hospital | ||
0,undergoing PCR test,20220617T0000,20220617T0000,00:00:00,Diagnosis,Home | ||
0,receiving negative test results,20220617T0000,20220619T0000,48:00:00,Diagnosis,Doctors | ||
0,getting infected and testing positive,20220617T0000,20220617T0000,00:00:00,Symptom Onset,Home | ||
case_id,event_information,event_type,location,"(token1, lin_prob1)","(token2, lin_prob2)" | ||
0,experiencing first Covid-19 symptoms in June 2022,Symptom Onset,Doctors,"('Doctors', 47.42)","('Home', 45.07)" | ||
0,brushing off symptoms as common cold,Lifestyle Change,Home,"('Home', 86.14)","('Doctors', 12.56)" | ||
0,developing high fever and difficulty breathing,Symptom Onset,Doctors,"('Doctors', 68.56)","('Hospital', 13.31)" | ||
0,getting tested for Covid-19,Diagnosis,Doctors,"('Doctors', 91.6)","('Hospital', 3.18)" | ||
0,visiting local testing center,Doctor visit,Doctors,"('Doctors', 99.36)","('Doctor', 0.27)" | ||
0,undergoing PCR test,Diagnosis,Doctors,"('Doctors', 82.7)","('Di', 7.96)" | ||
0,receiving negative test results,"The bulletpoint ""receiving negative test results"" can be classified as ""Symptom Offset"".",Doctors,"('Doctors', 55.6)","('Home', 39.8)" | ||
0,getting infected and testing positive,Diagnosis,Doctors,"('Doctors', 90.93)","('Hospital', 3.57)" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import pandas as pd | ||
import numpy as np | ||
|
||
from . import utils as u | ||
from . import prompts as p | ||
from . import input_handling as ih | ||
|
||
|
||
def measure_event_types(text):
    """Classify every bulletpoint's event type and record the top token probabilities.

    The text is converted to bulletpoints with
    ``ih.convert_text_to_bulletpoints``. For each row, the model is queried
    through ``u.query_gpt`` with ``logprobs=True`` and ``top_logprobs=2``; the
    returned content and the two top candidates, each as a
    ``(token, linear probability in percent)`` tuple, are collected.

    Args:
        text: The raw input text handed to ``ih.convert_text_to_bulletpoints``.

    Returns:
        The bulletpoint DataFrame with the columns ``event_type``,
        ``(token1, lin_prob1)`` and ``(token2, lin_prob2)`` appended.
    """
    columns = ["event_type", "(token1, lin_prob1)", "(token2, lin_prob2)"]
    candidate_count = len(columns) - 1
    df = ih.convert_text_to_bulletpoints(text)

    rows = []
    for item in df.values.tolist():
        # NOTE(review): assumes the first column holds the bulletpoint text.
        messages = [
            {"role": "system", "content": p.EVENT_TYPE_CONTEXT},
            {
                "role": "user",
                "content": p.EVENT_TYPE_PROMPT + "\nThe bulletpoint: " + item[0],
            },
            {"role": "assistant", "content": p.EVENT_TYPE_ANSWER},
        ]
        content, top_logprobs = u.query_gpt(messages, logprobs=True, top_logprobs=2)

        candidates = [
            (entry.token, calculate_linear_probability(entry.logprob))
            for entry in top_logprobs
        ]
        # Force exactly one value per candidate column so that a shorter or
        # longer candidate list does not abort the run with a ValueError.
        candidates = (candidates + [None] * candidate_count)[:candidate_count]
        row = [content, *candidates]
        rows.append(row)

        # Log and echo each result as soon as it is available.
        row_text = pd.DataFrame([row], columns=columns).to_string()
        ih.document_intermediates(row_text)
        print(row_text)

    # Build the result frame once (instead of concatenating inside the loop)
    # and reuse df's index so the column-wise concat lines up row by row.
    new_df = pd.DataFrame(rows, columns=columns, index=df.index)
    return pd.concat([df, new_df], axis=1)
|
||
def measure_location(text):
    """Classify every bulletpoint's location and record the top token probabilities.

    The text is converted to bulletpoints and event types are added via
    ``ih.add_event_types(ih.convert_text_to_bulletpoints(text))``. For each
    row, the model is queried through ``u.query_gpt`` with ``logprobs=True``
    and ``top_logprobs=2``, passing the row's first column and its
    ``event_type`` value in the prompt. The returned content and the two top
    candidates, each as a ``(token, linear probability in percent)`` tuple,
    are collected.

    Args:
        text: The raw input text handed to ``ih.convert_text_to_bulletpoints``.

    Returns:
        The DataFrame with event types, with the columns ``location``,
        ``(token1, lin_prob1)`` and ``(token2, lin_prob2)`` appended.
    """
    columns = ["location", "(token1, lin_prob1)", "(token2, lin_prob2)"]
    candidate_count = len(columns) - 1
    df = ih.add_event_types(ih.convert_text_to_bulletpoints(text))
    event_type_key = df.columns.get_loc("event_type")

    rows = []
    for item in df.values.tolist():
        # NOTE(review): assumes the first column holds the bulletpoint text.
        messages = [
            {"role": "system", "content": p.LOCATION_CONTEXT},
            {
                "role": "user",
                "content": p.LOCATION_PROMPT
                + item[0]
                + "\nThe category: "
                + item[event_type_key],
            },
            {"role": "assistant", "content": p.LOCATION_ANSWER},
        ]
        content, top_logprobs = u.query_gpt(messages, logprobs=True, top_logprobs=2)

        candidates = [
            (entry.token, calculate_linear_probability(entry.logprob))
            for entry in top_logprobs
        ]
        # Force exactly one value per candidate column so that a shorter or
        # longer candidate list does not abort the run with a ValueError.
        candidates = (candidates + [None] * candidate_count)[:candidate_count]
        row = [content, *candidates]
        rows.append(row)

        # Log each result as soon as it is available.
        ih.document_intermediates(pd.DataFrame([row], columns=columns).to_string())

    # Build the result frame once (instead of concatenating inside the loop)
    # and reuse df's index so the column-wise concat lines up row by row.
    new_df = pd.DataFrame(rows, columns=columns, index=df.index)
    return pd.concat([df, new_df], axis=1)
|
||
|
||
|
||
|
||
def calculate_linear_probability(logprob):
    """Convert a natural-log probability into a percentage rounded to two decimals.

    Args:
        logprob: Natural logarithm of a probability (e.g. ``0.0`` for certainty).

    Returns:
        ``exp(logprob) * 100`` rounded to two decimal places, as a builtin
        ``float``.
    """
    # Convert the NumPy scalar to a builtin float: NumPy 2 renders scalars as
    # ``np.float64(47.42)`` when they are nested in a tuple, which would leak
    # into the "(token, lin_prob)" cells once they are printed or written out.
    return float(np.round(np.exp(logprob) * 100, 2))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters