Skip to content

Commit

Permalink
refactor: typing for views.py
Browse files Browse the repository at this point in the history
  • Loading branch information
thangixd committed May 24, 2024
1 parent 58b5b17 commit fceb2b6
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions tracex_project/db_results/views.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
"""This file contains the views for the database result app."""
from typing import Tuple, List

import pandas as pd
import plotly.graph_objects as go

from db_results.forms import PatientJourneySelectForm, EvaluationForm
from django.db.models import Q
from django.db.models.query import QuerySet
from django.http import HttpRequest
from django.urls import reverse_lazy
from django.views.generic import FormView, TemplateView

from extraction.models import Trace, PatientJourney, Cohort
from plotly.offline import plot
from tracex.logic import utils as u
from tracex.logic.constants import ACTIVITY_KEYS, EVENT_TYPES, LOCATIONS
from tracex.views import DownloadXesView
from db_results.forms import PatientJourneySelectForm, EvaluationForm
from extraction.models import Trace, PatientJourney, Cohort


class DbResultsOverviewView(TemplateView):
Expand Down Expand Up @@ -65,7 +67,7 @@ def get_context_data(self, **kwargs):

return context

def get_latest_trace_df(self):
def get_latest_trace_df(self) -> pd.DataFrame:
"""
Fetch the DataFrame for the latest trace of a specific patient journey stored in the session.
Expand All @@ -88,7 +90,7 @@ def get_latest_trace_df(self):
)
return u.DataFrameUtilities.get_events_df(query_last_trace)

def update_context_with_counts(self, context, trace_df):
def update_context_with_counts(self, context, trace_df: pd.DataFrame):
"""Update the given context dictionary with count statistics related to patient journeys and traces."""
patient_journey_name = self.request.session["patient_journey_name"]

Expand All @@ -100,7 +102,7 @@ def update_context_with_counts(self, context, trace_df):
"traces_count": Trace.manager.filter(patient_journey__name=patient_journey_name).count()
})

def update_context_with_charts(self, context, trace_df):
def update_context_with_charts(self, context, trace_df: pd.DataFrame):
"""Update the context dictionary with chart visualizations."""
relevance_counts = trace_df["activity_relevance"].value_counts()
timestamp_correctness_counts = trace_df["timestamp_correctness"].value_counts()
Expand All @@ -118,7 +120,7 @@ def update_context_with_charts(self, context, trace_df):
"average_timestamp_correctness": round(trace_df["correctness_confidence"].mean(), 2)
})

def update_context_with_data_tables(self, context, trace_df):
def update_context_with_data_tables(self, context, trace_df: pd.DataFrame):
"""Format trace data into styled HTML tables and add them to the context."""

# Apply renaming, styling, and convert to HTML, then update the context
Expand All @@ -142,7 +144,7 @@ def update_context_with_data_tables(self, context, trace_df):
})

@staticmethod
def color_relevance(row):
def color_relevance(row: pd.Series) -> List[str]:
"""Apply background color styling to a DataFrame row based on the activity relevance."""
activity_relevance = row["Activity Relevance"]
if activity_relevance == "Moderate Relevance":
Expand All @@ -154,7 +156,7 @@ def color_relevance(row):
return [""] * len(row)

@staticmethod
def color_timestamp_correctness(row):
def color_timestamp_correctness(row: pd.Series) -> List[str]:
"""Apply background color styling to cells in a DataFrame row based on timestamp correctness and confidence."""
correctness_confidence = row["Correctness Confidence"]
confidence_index = row.index.get_loc("Correctness Confidence")
Expand All @@ -179,7 +181,7 @@ def color_timestamp_correctness(row):
return styles

@staticmethod
def create_pie_chart(data):
def create_pie_chart(data: pd.Series) -> str:
"""Create a pie chart from the provided data using Plotly."""
return plot(
go.Figure(
Expand All @@ -196,7 +198,7 @@ def create_pie_chart(data):
)

@staticmethod
def create_bar_chart(data, x_title, y_title):
def create_bar_chart(data: pd.Series, x_title: str, y_title: str) -> str:
"""Create a bar chart from the provided data using Plotly."""
return plot(
go.Figure(
Expand Down Expand Up @@ -252,7 +254,7 @@ def get_context_data(self, **kwargs):
self.request.session["event_log"] = event_log_df.to_json()
return context

def get_traces_and_events(self):
def get_traces_and_events(self) -> Tuple[QuerySet, pd.DataFrame]:
"""
Fetch trace data and corresponding event logs based on the current session configuration.
Expand All @@ -272,7 +274,7 @@ def get_traces_and_events(self):
traces = Trace.manager.all()
return traces, event_log_df

def get_cohorts_data(self, traces):
def get_cohorts_data(self, traces: QuerySet) -> pd.DataFrame:
"""Extract and format cohort data from given traces for further processing and visualization."""
cohorts = Cohort.manager.filter(trace__in=traces)
cohorts_data = list(cohorts.values("trace", "age", "sex", "origin", "condition", "preexisting_condition"))
Expand All @@ -281,7 +283,7 @@ def get_cohorts_data(self, traces):
cohorts_df["age"] = cohorts_df["age"].astype(pd.Int64Dtype())
return cohorts_df

def filter_and_cleanup_event_log(self, event_log_df, filter_settings):
def filter_and_cleanup_event_log(self, event_log_df: pd.DataFrame, filter_settings: dict) -> pd.DataFrame:
"""Apply user-defined filters to the event log data and clean up unnecessary columns."""
filter_dict = {
"event_type": filter_settings.get("event_types"),
Expand All @@ -292,7 +294,8 @@ def filter_and_cleanup_event_log(self, event_log_df, filter_settings):
columns=["activity_relevance", "timestamp_correctness", "correctness_confidence"])
return event_log_df

def generate_dfg_and_tables(self, event_log_df, cohorts_df, filter_settings):
def generate_dfg_and_tables(self, event_log_df: pd.DataFrame, cohorts_df: pd.DataFrame,
filter_settings: dict) -> dict:
"""Generate visualizations and HTML tables for the provided event log and cohort data."""
activity_key = filter_settings.get("activity_key")
return {
Expand Down Expand Up @@ -347,7 +350,7 @@ class DownloadXesEvaluationView(DownloadXesView):
"""View to download evaluation data in XES format based on the event log stored in the session."""

@staticmethod
def process_trace_type(request, trace_type):
def process_trace_type(request: HttpRequest, trace_type: str):
"""Process and provide the XES files to be downloaded based on the trace type."""
configuration = request.session.get("filter_settings")
activity_key = configuration.get("activity_key")
Expand Down

0 comments on commit fceb2b6

Please sign in to comment.