diff --git a/codes/ecg-image-generator/README.md b/codes/ecg-image-generator/README.md index 129dba8..74a247b 100644 --- a/codes/ecg-image-generator/README.md +++ b/codes/ecg-image-generator/README.md @@ -11,8 +11,8 @@ The process of scanning and digitizing ECG images is governed by some fundamenta ## Installation - Setup with Conda env: ``` - conda env create -f environment_droplet.yml - conda activate myenv + conda env create -f environment.yml + conda activate ecg-image-generator ``` - Set up with pip: diff --git a/codes/ecg-image-generator/TemplateFiles/generate_template.py b/codes/ecg-image-generator/TemplateFiles/generate_template.py index 22c2f64..abc47f0 100644 --- a/codes/ecg-image-generator/TemplateFiles/generate_template.py +++ b/codes/ecg-image-generator/TemplateFiles/generate_template.py @@ -8,29 +8,47 @@ test_date1 = date(1940, 1, 1) +import random +from datetime import datetime, timedelta + +def random_datetime(start_date, end_date): + """Generate random datetime between two dates""" + time_between = end_date - start_date + total_seconds = time_between.total_seconds() + random_seconds = random.uniform(0, total_seconds) + return start_date + timedelta(seconds=random_seconds) + def generate_template(header_file): filename, extn = os.path.splitext(header_file) fields = wfdb.rdheader(filename) + start = datetime(1960, 1, 1) + end = datetime(2022, 1, 1) + random_dt = random_datetime(start, end) + if fields.comments == []: attributes = {} if fields.base_date is not None: attributes['Date'] = fields.base_date else: - attributes['Date'] = "" + attributes['Date'] = random_dt.strftime(random.choice(["%d/%m/%Y", "%d-%m-%Y"])) if fields.base_time is not None: attributes['Time'] = str(fields.base_time) else: - attributes['Time'] = "" + attributes['Time'] = random_dt.strftime("%H:%M:%S") attributes['ID'] = 'ID: ' + filename.split('/')[-1] attributes['Name'] = 'Name: ' if attributes['Date'] != "": attributes['Date'] = 'Date:' + str(attributes['Date']) if attributes['Time'] != "": attributes['Date'] += ', ' + attributes['Time'] + attributes['Sex'] = f"Sex: {random.choice(['Male', 'Female'])}" + attributes['Age'] = f"Age: {random.choice(range(12, 75))} yrs" printedText = {} printedText[0] = ['ID', 'Name', 'Date'] + printedText[1] = ["Age"] + printedText[2] = ["Sex"] return printedText, attributes, 1 @@ -39,21 +57,22 @@ def generate_template(header_file): attributes = {} + if fields.base_date is not None: attributes['Date'] = fields.base_date else: - attributes['Date'] = "" + attributes['Date'] = random_dt.strftime(random.choice(["%d/%m/%Y", "%d-%m-%Y"])) if fields.base_time is not None: attributes['Time'] = str(fields.base_time) else: - attributes['Time'] = "" + attributes['Time'] = random_dt.strftime("%H:%M:%S") attributes['ID'] = 'ID: ' + filename.split('/')[-1] attributes['Name'] = 'Name: ' #+ str(str(random.randint(10**(8-1), (10**8)-1))) attributes['Height'] = '' attributes['Weight'] = '' - attributes['Sex'] = '' + attributes['Sex'] = random.choice(["Male", "Female"]) for c in comments: col = c.split(':')[0] diff --git a/codes/ecg-image-generator/ecg_plot.py b/codes/ecg-image-generator/ecg_plot.py index 2053e75..d1cadc9 100644 --- a/codes/ecg-image-generator/ecg_plot.py +++ b/codes/ecg-image-generator/ecg_plot.py @@ -1,232 +1,303 @@ +from io import BytesIO import os +import cv2 import numpy as np import random import matplotlib.pyplot as plt import matplotlib from matplotlib.ticker import AutoMinorLocator from TemplateFiles.generate_template import generate_template -from math import ceil +from math import ceil from PIL import Image import csv -standard_values = {'y_grid_size' : 0.5, - 'x_grid_size' : 0.2, - 'y_grid_inch' : 5/25.4, - 'x_grid_inch' : 5/25.4, - 'grid_line_width' : 0.5, - 'lead_name_offset' : 0.5, - 'lead_fontsize' : 11, - 'x_gap' : 1, - 'y_gap' : 0.5, - 'display_factor' : 1, - 'line_width': 0.75, - 'row_height' : 8, - 'dc_offset_length' : 0.2, - 'lead_length' : 3, - 'V1_length' : 12, - 'width' : 11, - 'height' : 8.5 - } - -standard_major_colors = {'colour1' : (0.4274,0.196,0.1843), #brown - 'colour2' : (1,0.796,0.866), #pink - 'colour3' : (0.0,0.0, 0.4), #blue - 'colour4' : (0,0.3,0.0), #green - 'colour5' : (1,0,0) #red - } - - -standard_minor_colors = {'colour1' : (0.5882,0.4196,0.3960), - 'colour2' : (0.996,0.9294,0.9725), - 'colour3' : (0.0,0, 0.7), - 'colour4' : (0,0.8,0.3), - 'colour5' : (0.996,0.8745,0.8588) - } - -papersize_values = {'A0' : (33.1,46.8), - 'A1' : (33.1,23.39), - 'A2' : (16.54,23.39), - 'A3' : (11.69,16.54), - 'A4' : (8.27,11.69), - 'letter' : (8.5,11) - } - - -def inches_to_dots(value,resolution): - return (value * resolution) - -#Function to plot raw ecg signal +standard_values = { + "y_grid_size": 0.5, + "x_grid_size": 0.2, + "y_grid_inch": 5 / 25.4, + "x_grid_inch": 5 / 25.4, + "grid_line_width": 0.5, + "lead_name_offset": 0.5, + "lead_fontsize": 11, + "x_gap": 1, + "y_gap": 0.5, + "display_factor": 1, + "line_width": 0.75, + "row_height": 8, + "dc_offset_length": 0.2, + "lead_length": 3, + "V1_length": 12, + "width": 11, + "height": 8.5, +} + +standard_major_colors = { + "colour1": (0.4274, 0.196, 0.1843), # brown + "colour2": (1, 0.796, 0.866), # pink + "colour3": (0.0, 0.0, 0.4), # blue + "colour4": (0, 0.3, 0.0), # green + "colour5": (1, 0, 0), # red +} + + +standard_minor_colors = { + "colour1": (0.5882, 0.4196, 0.3960), + "colour2": (0.996, 0.9294, 0.9725), + "colour3": (0.0, 0, 0.7), + "colour4": (0, 0.8, 0.3), + "colour5": (0.996, 0.8745, 0.8588), +} + +papersize_values = { + "A0": (33.1, 46.8), + "A1": (33.1, 23.39), + "A2": (16.54, 23.39), + "A3": (11.69, 16.54), + "A4": (8.27, 11.69), + "letter": (8.5, 11), +} + + +def inches_to_dots(value, resolution): + return value * resolution + + +def create_fig_ax( + width, + height, + resolution, + title, + y_min, + y_max, + x_min, + x_max, +) -> tuple[plt.Figure, plt.Axes]: + fig, ax = plt.subplots(figsize=(width, height), dpi=resolution) + fig.subplots_adjust(hspace=0, wspace=0, left=0, right=1, bottom=0, top=1) + fig.suptitle(title) + ax.set_ylim(y_min, y_max) + ax.set_xlim(x_min, x_max) + ax.tick_params(axis="x", colors="white") + ax.tick_params(axis="y", colors="white") + return fig, ax + +color_dict = { + 0: 'red', + 1: 'blue', + 2: 'green', + 3: 'orange', + 4: 'purple', + 5: 'brown', + 6: 'pink', + 7: 'gray', + 8: 'olive', + 9: 'cyan', + 10: 'magenta', + 11: 'yellow', + 12: 'black' +} + +# Function to plot raw ecg signal def ecg_plot( - ecg, - configs, - sample_rate, - columns, - rec_file_name, - output_dir, - resolution, - pad_inches, - lead_index, - full_mode, - store_text_bbox, - full_header_file, - units = '', - papersize = '', - x_gap = standard_values['x_gap'], - y_gap = standard_values['y_gap'], - display_factor = standard_values['display_factor'], - line_width = standard_values['line_width'], - title = '', - style = None, - row_height = standard_values['row_height'], - show_lead_name = True, - show_grid = False, - show_dc_pulse = False, - y_grid = 0, - x_grid = 0, - standard_colours = False, - bbox = False, - print_txt=False, - json_dict=dict(), - start_index=-1, - store_configs=0, - lead_length_in_seconds=10 - ): - #Inputs : - #ecg - Dictionary of ecg signal with lead names as keys - #sample_rate - Sampling rate of the ecg signal - #lead_index - Order of lead indices to be plotted - #columns - Number of columns to be plotted in each row - #x_gap - gap between paper x axis border and signal plot - #y_gap - gap between paper y axis border and signal plot - #line_width - Width of line tracing the ecg - #title - Title of figure - #style - Black and white or colour - #row_height - gap between corresponding ecg rows - #show_lead_name - Option to show lead names or skip - #show_dc_pulse - Option to show dc pulse - #show_grid - Turn grid on or off - - - #Initialize some params - #secs represents how many seconds of ecg are plotted - #leads represent number of leads in the ecg - #rows are calculated based on corresponding number of leads and number of columns + ecg, + configs, + sample_rate, + columns, + rec_file_name, + output_dir, + resolution, + pad_inches, + lead_index, + full_mode, + store_text_bbox, + full_header_file, + units="", + papersize="", + x_gap=standard_values["x_gap"], + y_gap=standard_values["y_gap"], + display_factor=standard_values["display_factor"], + line_width=standard_values["line_width"], + title="", + style=None, + row_height=standard_values["row_height"], + show_lead_name=True, + show_grid=False, + show_dc_pulse=False, + y_grid=0, + x_grid=0, + standard_colours=False, + bbox=False, + print_txt=False, + json_dict=dict(), + start_index=-1, + store_configs=0, + lead_length_in_seconds=10, +): + # Inputs : + # ecg - Dictionary of ecg signal with lead names as keys + # sample_rate - Sampling rate of the ecg signal + # lead_index - Order of lead indices to be plotted + # columns - Number of columns to be plotted in each row + # x_gap - gap between paper x axis border and signal plot + # y_gap - gap between paper y axis border and signal plot + # line_width - Width of line tracing the ecg + # title - Title of figure + # style - Black and white or colour + # row_height - gap between corresponding ecg rows + # show_lead_name - Option to show lead names or skip + # show_dc_pulse - Option to show dc pulse + # show_grid - Turn grid on or off + + # Initialize some params + # secs represents how many seconds of ecg are plotted + # leads represent number of leads in the ecg + # rows are calculated based on corresponding number of leads and number of columns matplotlib.use("Agg") - #check if the ecg dict is empty + # check if the ecg dict is empty if ecg == {}: - return + return secs = lead_length_in_seconds leads = len(lead_index) - rows = int(ceil(leads/columns)) + rows = int(ceil(leads / columns)) - if(full_mode!='None'): - rows+=1 - leads+=1 - - #Grid calibration - #Each big grid corresponds to 0.2 seconds and 0.5 mV - #To do: Select grid size in a better way - y_grid_size = standard_values['y_grid_size'] - x_grid_size = standard_values['x_grid_size'] - grid_line_width = standard_values['grid_line_width'] - lead_name_offset = standard_values['lead_name_offset'] - lead_fontsize = standard_values['lead_fontsize'] + if full_mode != "None": + rows += 1 + leads += 1 + # Grid calibration + # Each big grid corresponds to 0.2 seconds and 0.5 mV + # To do: Select grid size in a better way + y_grid_size = standard_values["y_grid_size"] + x_grid_size = standard_values["x_grid_size"] + grid_line_width = standard_values["grid_line_width"] + lead_name_offset = standard_values["lead_name_offset"] + lead_fontsize = standard_values["lead_fontsize"] - #Set max and min coordinates to mark grid. Offset x_max slightly (i.e by 1 column width) + # Set max and min coordinates to mark grid. Offset x_max slightly (i.e by 1 column width) - if papersize=='': - width = standard_values['width'] - height = standard_values['height'] + if papersize == "": + width = standard_values["width"] + height = standard_values["height"] else: width = papersize_values[papersize][1] height = papersize_values[papersize][0] - - y_grid = standard_values['y_grid_inch'] - x_grid = standard_values['x_grid_inch'] - y_grid_dots = y_grid*resolution - x_grid_dots = x_grid*resolution - - #row_height = height * y_grid_size/(y_grid*(rows+2)) - row_height = (height * y_grid_size/y_grid)/(rows+2) + + y_grid = standard_values["y_grid_inch"] + x_grid = standard_values["x_grid_inch"] + y_grid_dots = y_grid * resolution + x_grid_dots = x_grid * resolution + + # row_height = height * y_grid_size/(y_grid*(rows+2)) + row_height = (height * y_grid_size / y_grid) / (rows + 2) x_max = width * x_grid_size / x_grid x_min = 0 - x_gap = np.floor(((x_max - (columns*secs))/2)/0.2)*0.2 + x_gap = np.floor(((x_max - (columns * secs)) / 2) / 0.2) * 0.2 y_min = 0 - y_max = height * y_grid_size/y_grid + y_max = height * y_grid_size / y_grid - json_dict['width'] = int(width*resolution) - json_dict['height'] = int(height*resolution) - #Set figure and subplot sizes + json_dict["width"] = int(width * resolution) + json_dict["height"] = int(height * resolution) + # Set figure and subplot sizes fig, ax = plt.subplots(figsize=(width, height), dpi=resolution) - - fig.subplots_adjust( - hspace = 0, - wspace = 0, - left = 0, - right = 1, - bottom = 0, - top = 1 - ) + fig.subplots_adjust(hspace=0, wspace=0, left=0, right=1, bottom=0, top=1) fig.suptitle(title) - #Mark grid based on whether we want black and white or colour - - if (style == 'bw'): - color_major = (0.4,0.4,0.4) + fig_text, ax_text = create_fig_ax( + width, + height, + resolution, + title, + y_min, + y_max, + x_min, + x_max, + ) + fig_leads, ax_leads = create_fig_ax( + width, + height, + resolution, + title, + y_min, + y_max, + x_min, + x_max, + ) + fig_grid, ax_grid = create_fig_ax( + width, + height, + resolution, + title, + y_min, + y_max, + x_min, + x_max, + ) + + # Mark grid based on whether we want black and white or colour + + if style == "bw": + color_major = (0.4, 0.4, 0.4) color_minor = (0.75, 0.75, 0.75) - color_line = (0,0,0) - elif(standard_colours > 0): + color_line = (0, 0, 0) + elif standard_colours > 0: random_colour_index = standard_colours - color_major = standard_major_colors['colour'+str(random_colour_index)] - color_minor = standard_minor_colors['colour'+str(random_colour_index)] - grey_random_color = random.uniform(0,0.2) - color_line = (grey_random_color,grey_random_color,grey_random_color) + color_major = standard_major_colors["colour" + str(random_colour_index)] + color_minor = standard_minor_colors["colour" + str(random_colour_index)] + grey_random_color = random.uniform(0, 0.2) + color_line = (grey_random_color, grey_random_color, grey_random_color) else: - major_random_color_sampler_red = random.uniform(0,0.8) - major_random_color_sampler_green = random.uniform(0,0.5) - major_random_color_sampler_blue = random.uniform(0,0.5) + major_random_color_sampler_red = random.uniform(0, 0.8) + major_random_color_sampler_green = random.uniform(0, 0.5) + major_random_color_sampler_blue = random.uniform(0, 0.5) - minor_offset = random.uniform(0,0.2) + minor_offset = random.uniform(0, 0.2) minor_random_color_sampler_red = major_random_color_sampler_red + minor_offset - minor_random_color_sampler_green = random.uniform(0,0.5) + minor_offset - minor_random_color_sampler_blue = random.uniform(0,0.5) + minor_offset + minor_random_color_sampler_green = random.uniform(0, 0.5) + minor_offset + minor_random_color_sampler_blue = random.uniform(0, 0.5) + minor_offset + + grey_random_color = random.uniform(0, 0.2) + color_major = ( + major_random_color_sampler_red, + major_random_color_sampler_green, + major_random_color_sampler_blue, + ) + color_minor = ( + minor_random_color_sampler_red, + minor_random_color_sampler_green, + minor_random_color_sampler_blue, + ) - grey_random_color = random.uniform(0,0.2) - color_major = (major_random_color_sampler_red,major_random_color_sampler_green,major_random_color_sampler_blue) - color_minor = (minor_random_color_sampler_red,minor_random_color_sampler_green,minor_random_color_sampler_blue) - - color_line = (grey_random_color,grey_random_color,grey_random_color) + color_line = (grey_random_color, grey_random_color, grey_random_color) - #Set grid - #Standard ecg has grid size of 0.5 mV and 0.2 seconds. Set ticks accordingly - - ax.set_ylim(y_min,y_max) - ax.set_xlim(x_min,x_max) - ax.tick_params(axis='x', colors='white') - ax.tick_params(axis='y', colors='white') - - #Step size will be number of seconds per sample i.e 1/sampling_rate - step = (1.0/sample_rate) + # Set grid + # Standard ecg has grid size of 0.5 mV and 0.2 seconds. Set ticks accordingly + + ax.set_ylim(y_min, y_max) + ax.set_xlim(x_min, x_max) + ax.tick_params(axis="x", colors="white") + ax.tick_params(axis="y", colors="white") + + # Step size will be number of seconds per sample i.e 1/sampling_rate + step = 1.0 / sample_rate dc_offset = 0 - if(show_dc_pulse): - dc_offset = sample_rate*standard_values['dc_offset_length']*step - #Iterate through each lead in lead_index array. - y_offset = (row_height/2) + if show_dc_pulse: + dc_offset = sample_rate * standard_values["dc_offset_length"] * step + # Iterate through each lead in lead_index array. + y_offset = row_height / 2 x_offset = 0 leads_ds = [] - leadNames_12 = configs['leadNames_12'] - tickLength = configs['tickLength'] - tickSize_step = configs['tickSize_step'] + leadNames_12 = configs["leadNames_12"] + tickLength = configs["tickLength"] + tickSize_step = configs["tickSize_step"] for i in np.arange(len(lead_index)): current_lead_ds = dict() @@ -235,231 +306,324 @@ def ecg_plot( leadName = leadNames_12[i] else: leadName = lead_index[i] - #y_offset is computed by shifting by a certain offset based on i, and also by row_height/2 to account for half the waveform below the axis - if(i%columns==0): - + # y_offset is computed by shifting by a certain offset based on i, and also by row_height/2 to account for half the waveform below the axis + if i % columns == 0: y_offset += row_height - - #x_offset will be distance by which we shift the plot in each iteration - if(columns>1): - x_offset = (i%columns)*secs - + + # x_offset will be distance by which we shift the plot in each iteration + if columns > 1: + x_offset = (i % columns) * secs + else: x_offset = 0 - #Create dc pulse wave to plot at the beginning of plot. Dc pulse will be 0.2 seconds - x_range = np.arange(0,sample_rate*standard_values['dc_offset_length']*step + 4*step,step) + # Create dc pulse wave to plot at the beginning of plot. Dc pulse will be 0.2 seconds + x_range = np.arange( + 0, sample_rate * standard_values["dc_offset_length"] * step + 4 * step, step + ) dc_pulse = np.ones(len(x_range)) - dc_pulse = np.concatenate(((0,0),dc_pulse[2:-2],(0,0))) - - #Print lead name at .5 ( or 5 mm distance) from plot - if(show_lead_name): - t1 = ax.text(x_offset + x_gap + dc_offset, - y_offset-lead_name_offset - 0.2, - leadName, - fontsize=lead_fontsize) - - if (store_text_bbox): + dc_pulse = np.concatenate(((0, 0), dc_pulse[2:-2], (0, 0))) + + # Print lead name at .5 ( or 5 mm distance) from plot + if show_lead_name: + ax_text.text( + x_offset + x_gap + dc_offset, + y_offset - lead_name_offset - 0.2, + leadName, + fontsize=lead_fontsize, + # color=color_dict[i], + ) + t1 = ax.text( + x_offset + x_gap + dc_offset, + y_offset - lead_name_offset - 0.2, + leadName, + fontsize=lead_fontsize, + ) + + if store_text_bbox: renderer1 = fig.canvas.get_renderer() transf = ax.transData.inverted() - bb = t1.get_window_extent() - x1 = bb.x0*resolution/fig.dpi - y1 = bb.y0*resolution/fig.dpi - x2 = bb.x1*resolution/fig.dpi - y2 = bb.y1*resolution/fig.dpi + bb = t1.get_window_extent() + x1 = bb.x0 * resolution / fig.dpi + y1 = bb.y0 * resolution / fig.dpi + x2 = bb.x1 * resolution / fig.dpi + y2 = bb.y1 * resolution / fig.dpi box_dict = dict() x1 = int(x1) y1 = int(y1) x2 = int(x2) y2 = int(y2) - box_dict[0] = [round(json_dict['height'] - y2, 2), round(x1, 2)] - box_dict[1] = [round(json_dict['height'] - y2, 2), round(x2, 2)] - box_dict[2] = [round(json_dict['height'] - y1, 2), round(x2, 2)] - box_dict[3] = [round(json_dict['height'] - y1, 2), round(x1, 2)] + box_dict[0] = [round(json_dict["height"] - y2, 2), round(x1, 2)] + box_dict[1] = [round(json_dict["height"] - y2, 2), round(x2, 2)] + box_dict[2] = [round(json_dict["height"] - y1, 2), round(x2, 2)] + box_dict[3] = [round(json_dict["height"] - y1, 2), round(x1, 2)] current_lead_ds["text_bounding_box"] = box_dict current_lead_ds["lead_name"] = leadName - #If we are plotting the first row-1 plots, we plot the dc pulse prior to adding the waveform - if(columns == 1 and i in np.arange(0,rows)): - if(show_dc_pulse): - #Plot dc pulse for 0.2 seconds with 2 trailing and leading zeros to get the pulse - t1 = ax.plot(x_range + x_offset + x_gap, - dc_pulse+y_offset, - linewidth=line_width * 1.5, - color=color_line - ) - if (bbox): + # If we are plotting the first row-1 plots, we plot the dc pulse prior to adding the waveform + if columns == 1 and i in np.arange(0, rows): + if show_dc_pulse: + # Plot dc pulse for 0.2 seconds with 2 trailing and leading zeros to get the pulse + ax_text.plot( + x_range + x_offset + x_gap, + dc_pulse + y_offset, + linewidth=line_width * 1.5, + # color=color_line, + ) + t1 = ax.plot( + x_range + x_offset + x_gap, + dc_pulse + y_offset, + linewidth=line_width * 1.5, + color=color_line, + ) + if bbox: renderer1 = fig.canvas.get_renderer() transf = ax.transData.inverted() - bb = t1[0].get_window_extent() - x1, y1 = bb.x0*resolution/fig.dpi, bb.y0*resolution/fig.dpi - x2, y2 = bb.x1*resolution/fig.dpi, bb.y1*resolution/fig.dpi - - - elif(i%columns == 0): - if(show_dc_pulse): - #Plot dc pulse for 0.2 seconds with 2 trailing and leading zeros to get the pulse - t1 = ax.plot(np.arange(0,sample_rate*standard_values['dc_offset_length']*step + 4*step,step) + x_offset + x_gap, - dc_pulse+y_offset, - linewidth=line_width * 1.5, - color=color_line - ) - if (bbox): + bb = t1[0].get_window_extent() + x1, y1 = bb.x0 * resolution / fig.dpi, bb.y0 * resolution / fig.dpi + x2, y2 = bb.x1 * resolution / fig.dpi, bb.y1 * resolution / fig.dpi + + elif i % columns == 0: + if show_dc_pulse: + # Plot dc pulse for 0.2 seconds with 2 trailing and leading zeros to get the pulse + ax_text.plot( + np.arange( + 0, + sample_rate * standard_values["dc_offset_length"] * step + + 4 * step, + step, + ) + + x_offset + + x_gap, + dc_pulse + y_offset, + linewidth=line_width * 1.5, + # color=color_line, + ) + t1 = ax.plot( + np.arange( + 0, + sample_rate * standard_values["dc_offset_length"] * step + + 4 * step, + step, + ) + + x_offset + + x_gap, + dc_pulse + y_offset, + linewidth=line_width * 1.5, + color=color_line, + ) + if bbox: renderer1 = fig.canvas.get_renderer() transf = ax.transData.inverted() - bb = t1[0].get_window_extent() - x1, y1 = bb.x0*resolution/fig.dpi, bb.y0*resolution/fig.dpi - x2, y2 = bb.x1*resolution/fig.dpi, bb.y1*resolution/fig.dpi - - t1 = ax.plot(np.arange(0,len(ecg[leadName])*step,step) + x_offset + dc_offset + x_gap, - ecg[leadName] + y_offset, - linewidth=line_width, - color=color_line - ) - - x_vals = np.arange(0,len(ecg[leadName])*step,step) + x_offset + dc_offset + x_gap + bb = t1[0].get_window_extent() + x1, y1 = bb.x0 * resolution / fig.dpi, bb.y0 * resolution / fig.dpi + x2, y2 = bb.x1 * resolution / fig.dpi, bb.y1 * resolution / fig.dpi + + ax_leads.plot( + np.arange(0, len(ecg[leadName]) * step, step) + + x_offset + + dc_offset + + x_gap, + ecg[leadName] + y_offset, + linewidth=line_width, + color="black", + ) + t1 = ax.plot( + np.arange(0, len(ecg[leadName]) * step, step) + + x_offset + + dc_offset + + x_gap, + ecg[leadName] + y_offset, + linewidth=line_width, + color=color_line, + ) + + x_vals = ( + np.arange(0, len(ecg[leadName]) * step, step) + x_offset + dc_offset + x_gap + ) y_vals = ecg[leadName] + y_offset - if (bbox): + if bbox: renderer1 = fig.canvas.get_renderer() transf = ax.transData.inverted() - bb = t1[0].get_window_extent() - if show_dc_pulse == False or (columns == 4 and (i != 0 and i != 4 and i != 8)): - x1, y1 = bb.x0*resolution/fig.dpi, bb.y0*resolution/fig.dpi - x2, y2 = bb.x1*resolution/fig.dpi, bb.y1*resolution/fig.dpi + bb = t1[0].get_window_extent() + if show_dc_pulse == False or ( + columns == 4 and (i != 0 and i != 4 and i != 8) + ): + x1, y1 = bb.x0 * resolution / fig.dpi, bb.y0 * resolution / fig.dpi + x2, y2 = bb.x1 * resolution / fig.dpi, bb.y1 * resolution / fig.dpi else: - y1 = min(y1, bb.y0*resolution/fig.dpi) - y2 = max(y2, bb.y1*resolution/fig.dpi) - x2 = bb.x1*resolution/fig.dpi + y1 = min(y1, bb.y0 * resolution / fig.dpi) + y2 = max(y2, bb.y1 * resolution / fig.dpi) + x2 = bb.x1 * resolution / fig.dpi box_dict = dict() x1 = int(x1) y1 = int(y1) x2 = int(x2) y2 = int(y2) - box_dict[0] = [round(json_dict['height'] - y2, 2), round(x1, 2)] - box_dict[1] = [round(json_dict['height'] - y2, 2), round(x2, 2)] - box_dict[2] = [round(json_dict['height'] - y1, 2), round(x2, 2)] - box_dict[3] = [round(json_dict['height'] - y1, 2), round(x1, 2)] + box_dict[0] = [round(json_dict["height"] - y2, 2), round(x1, 2)] + box_dict[1] = [round(json_dict["height"] - y2, 2), round(x2, 2)] + box_dict[2] = [round(json_dict["height"] - y1, 2), round(x2, 2)] + box_dict[3] = [round(json_dict["height"] - y1, 2), round(x1, 2)] current_lead_ds["lead_bounding_box"] = box_dict - + st = start_index - if columns == 4 and leadName in configs['format_4_by_3'][1]: - st = start_index + int(sample_rate*configs['paper_len']/columns) - elif columns == 4 and leadName in configs['format_4_by_3'][2]: - st = start_index + int(2*sample_rate*configs['paper_len']/columns) - elif columns == 4 and leadName in configs['format_4_by_3'][3]: - st = start_index + int(3*sample_rate*configs['paper_len']/columns) + if columns == 4 and leadName in configs["format_4_by_3"][1]: + st = start_index + int(sample_rate * configs["paper_len"] / columns) + elif columns == 4 and leadName in configs["format_4_by_3"][2]: + st = start_index + int(2 * sample_rate * configs["paper_len"] / columns) + elif columns == 4 and leadName in configs["format_4_by_3"][3]: + st = start_index + int(3 * sample_rate * configs["paper_len"] / columns) current_lead_ds["start_sample"] = st - current_lead_ds["end_sample"]= st + len(ecg[leadName]) + current_lead_ds["end_sample"] = st + len(ecg[leadName]) current_lead_ds["plotted_pixels"] = [] for j in range(len(x_vals)): xi, yi = x_vals[j], y_vals[j] xi, yi = ax.transData.transform((xi, yi)) - yi = json_dict['height'] - yi - current_lead_ds['plotted_pixels'].append([round(yi, 2), round(xi, 2)]) + yi = json_dict["height"] - yi + current_lead_ds["plotted_pixels"].append([round(yi, 2), round(xi, 2)]) leads_ds.append(current_lead_ds) - if columns > 1 and (i+1)%columns != 0: - sep_x = [len(ecg[leadName])*step + x_offset + dc_offset + x_gap] * round(tickLength*y_grid_dots) + if columns > 1 and (i + 1) % columns != 0: + sep_x = [len(ecg[leadName]) * step + x_offset + dc_offset + x_gap] * round( + tickLength * y_grid_dots + ) sep_x = np.array(sep_x) - sep_y = np.linspace(y_offset - tickLength/2*y_grid_dots*tickSize_step, y_offset + tickSize_step*y_grid_dots*tickLength/2, len(sep_x)) + sep_y = np.linspace( + y_offset - tickLength / 2 * y_grid_dots * tickSize_step, + y_offset + tickSize_step * y_grid_dots * tickLength / 2, + len(sep_x), + ) + ax_text.plot(sep_x, sep_y, linewidth=line_width * 3, color=color_line) ax.plot(sep_x, sep_y, linewidth=line_width * 3, color=color_line) - #Plotting longest lead for 12 seconds - if(full_mode!='None'): + # Plotting longest lead for 12 seconds + if full_mode != "None": current_lead_ds = dict() - if(show_lead_name): - t1 = ax.text(x_gap + dc_offset, - row_height/2-lead_name_offset, - full_mode, - fontsize=lead_fontsize) - - if (store_text_bbox): + if show_lead_name: + ax_text.text( + x_gap + dc_offset, + row_height / 2 - lead_name_offset, + full_mode, + fontsize=lead_fontsize, + # color=color_dict[12], + ) + t1 = ax.text( + x_gap + dc_offset, + row_height / 2 - lead_name_offset, + full_mode, + fontsize=lead_fontsize, + ) + + if store_text_bbox: renderer1 = fig.canvas.get_renderer() transf = ax.transData.inverted() - bb = t1.get_window_extent(renderer = fig.canvas.renderer) - x1 = bb.x0*resolution/fig.dpi - y1 = bb.y0*resolution/fig.dpi - x2 = bb.x1*resolution/fig.dpi - y2 = bb.y1*resolution/fig.dpi + bb = t1.get_window_extent(renderer=fig.canvas.renderer) + x1 = bb.x0 * resolution / fig.dpi + y1 = bb.y0 * resolution / fig.dpi + x2 = bb.x1 * resolution / fig.dpi + y2 = bb.y1 * resolution / fig.dpi box_dict = dict() x1 = int(x1) y1 = int(y1) x2 = int(x2) y2 = int(y2) - box_dict[0] = [round(json_dict['height'] - y2, 2), round(x1, 2)] - box_dict[1] = [round(json_dict['height'] - y2, 2), round(x2, 2)] - box_dict[2] = [round(json_dict['height'] - y1, 2), round(x2, 2)] - box_dict[3] = [round(json_dict['height'] - y1), round(x1, 2)] - current_lead_ds["text_bounding_box"] = box_dict + box_dict[0] = [round(json_dict["height"] - y2, 2), round(x1, 2)] + box_dict[1] = [round(json_dict["height"] - y2, 2), round(x2, 2)] + box_dict[2] = [round(json_dict["height"] - y1, 2), round(x2, 2)] + box_dict[3] = [round(json_dict["height"] - y1), round(x1, 2)] + current_lead_ds["text_bounding_box"] = box_dict current_lead_ds["lead_name"] = full_mode - if(show_dc_pulse): - t1 = ax.plot(x_range + x_gap, - dc_pulse + row_height/2-lead_name_offset + 0.8, - linewidth=line_width * 1.5, - color=color_line - ) - - if (bbox): - renderer1 = fig.canvas.get_renderer() - transf = ax.transData.inverted() - bb = t1[0].get_window_extent() - x1, y1 = bb.x0*resolution/fig.dpi, bb.y0*resolution/fig.dpi - x2, y2 = bb.x1*resolution/fig.dpi, bb.y1*resolution/fig.dpi - - dc_full_lead_offset = 0 - if(show_dc_pulse): - dc_full_lead_offset = sample_rate*standard_values['dc_offset_length']*step - - t1 = ax.plot(np.arange(0,len(ecg['full'+full_mode])*step,step) + x_gap + dc_full_lead_offset, - ecg['full'+full_mode] + row_height/2-lead_name_offset + 0.8, - linewidth=line_width, - color=color_line - ) - x_vals = np.arange(0,len(ecg['full'+full_mode])*step,step) + x_gap + dc_full_lead_offset - y_vals = ecg['full'+full_mode] + row_height/2-lead_name_offset + 0.8 + if show_dc_pulse: + ax_text.plot( + x_range + x_gap, + dc_pulse + row_height / 2 - lead_name_offset + 0.8, + linewidth=line_width * 1.5, + # color=color_line, + ) + t1 = ax.plot( + x_range + x_gap, + dc_pulse + row_height / 2 - lead_name_offset + 0.8, + linewidth=line_width * 1.5, + color=color_line, + ) + + if bbox: + renderer1 = fig.canvas.get_renderer() + transf = ax.transData.inverted() + bb = t1[0].get_window_extent() + x1, y1 = bb.x0 * resolution / fig.dpi, bb.y0 * resolution / fig.dpi + x2, y2 = bb.x1 * resolution / fig.dpi, bb.y1 * resolution / fig.dpi + + dc_full_lead_offset = 0 + if show_dc_pulse: + dc_full_lead_offset = ( + sample_rate * standard_values["dc_offset_length"] * step + ) + + ax_leads.plot( + np.arange(0, len(ecg["full" + full_mode]) * step, step) + + x_gap + + dc_full_lead_offset, + ecg["full" + full_mode] + row_height / 2 - lead_name_offset + 0.8, + linewidth=line_width, + color="black", + ) + + t1 = ax.plot( + np.arange(0, len(ecg["full" + full_mode]) * step, step) + + x_gap + + dc_full_lead_offset, + ecg["full" + full_mode] + row_height / 2 - lead_name_offset + 0.8, + linewidth=line_width, + color=color_line, + ) + x_vals = ( + np.arange(0, len(ecg["full" + full_mode]) * step, step) + + x_gap + + dc_full_lead_offset + ) + y_vals = ecg["full" + full_mode] + row_height / 2 - lead_name_offset + 0.8 - if (bbox): + if bbox: renderer1 = fig.canvas.get_renderer() transf = ax.transData.inverted() - bb = t1[0].get_window_extent() - if show_dc_pulse == False: - x1, y1 = bb.x0*resolution/fig.dpi, bb.y0*resolution/fig.dpi - x2, y2 = bb.x1*resolution/fig.dpi, bb.y1*resolution/fig.dpi + bb = t1[0].get_window_extent() + if show_dc_pulse == False: + x1, y1 = bb.x0 * resolution / fig.dpi, bb.y0 * resolution / fig.dpi + x2, y2 = bb.x1 * resolution / fig.dpi, bb.y1 * resolution / fig.dpi else: - y1 = min(y1, bb.y0*resolution/fig.dpi) - y2 = max(y2, bb.y1*resolution/fig.dpi) - x2 = bb.x1*resolution/fig.dpi + y1 = min(y1, bb.y0 * resolution / fig.dpi) + y2 = max(y2, bb.y1 * resolution / fig.dpi) + x2 = bb.x1 * resolution / fig.dpi box_dict = dict() x1 = int(x1) y1 = int(y1) x2 = int(x2) y2 = int(y2) - box_dict[0] = [round(json_dict['height'] - y2, 2), round(x1, 2)] - box_dict[1] = [round(json_dict['height'] - y2), round(x2, 2)] - box_dict[2] = [round(json_dict['height'] - y1, 2), round(x2, 2)] - box_dict[3] = [round(json_dict['height'] - y1, 2), round(x1, 2)] + box_dict[0] = [round(json_dict["height"] - y2, 2), round(x1, 2)] + box_dict[1] = [round(json_dict["height"] - y2, 2), round(x2, 2)] + box_dict[2] = [round(json_dict["height"] - y1, 2), round(x2, 2)] + box_dict[3] = [round(json_dict["height"] - y1, 2), round(x1, 2)] current_lead_ds["lead_bounding_box"] = box_dict current_lead_ds["start_sample"] = start_index - current_lead_ds["end_sample"] = start_index + len(ecg['full'+full_mode]) - current_lead_ds['plotted_pixels'] = [] + current_lead_ds["end_sample"] = start_index + len(ecg["full" + full_mode]) + current_lead_ds["plotted_pixels"] = [] for i in range(len(x_vals)): xi, yi = x_vals[i], y_vals[i] xi, yi = ax.transData.transform((xi, yi)) - yi = json_dict['height'] - yi - current_lead_ds['plotted_pixels'].append([round(yi, 2), round(xi, 2)]) + yi = json_dict["height"] - yi + current_lead_ds["plotted_pixels"].append([round(yi, 2), round(xi, 2)]) leads_ds.append(current_lead_ds) - - head, tail = os.path.split(rec_file_name) rec_file_name = os.path.join(output_dir, tail) - #printed template file + # printed template file if print_txt: x_offset = 0.05 y_offset = int(y_max) @@ -467,11 +631,11 @@ def ecg_plot( if flag: for l in range(0, len(printed_text), 1): - for j in printed_text[l]: - curr_l = '' + curr_l = "" if j in attributes.keys(): curr_l += str(attributes[j]) + ax_text.text(x_offset, y_offset, curr_l, fontsize=lead_fontsize) ax.text(x_offset, y_offset, curr_l, fontsize=lead_fontsize) x_offset += 3 @@ -482,38 +646,47 @@ def ecg_plot( ax.text(x_offset, y_offset, line, fontsize=lead_fontsize) y_offset -= 0.5 - #change x and y res - ax.text(2, 0.5, '25mm/s', fontsize=lead_fontsize) - ax.text(4, 0.5, '10mm/mV', fontsize=lead_fontsize) - - if(show_grid): - ax.set_xticks(np.arange(x_min,x_max,x_grid_size)) - ax.set_yticks(np.arange(y_min,y_max,y_grid_size)) + # change x and y res + ax.text(2, 0.5, "25mm/s", fontsize=lead_fontsize) + ax.text(4, 0.5, "10mm/mV", fontsize=lead_fontsize) + ax_text.text(2, 0.5, "25mm/s", fontsize=lead_fontsize) + ax_text.text(4, 0.5, "10mm/mV", fontsize=lead_fontsize) + + if show_grid: + ax.set_xticks(np.arange(x_min, x_max, x_grid_size)) + ax.set_yticks(np.arange(y_min, y_max, y_grid_size)) ax.minorticks_on() - + ax.xaxis.set_minor_locator(AutoMinorLocator(5)) - #set grid line style - ax.grid(which='major', linestyle='-', linewidth=grid_line_width, color=color_major) - - ax.grid(which='minor', linestyle='-', linewidth=grid_line_width, color=color_minor) - + # set grid line style + ax.grid( + which="major", linestyle="-", linewidth=grid_line_width, color=color_major + ) + + ax.grid( + which="minor", linestyle="-", linewidth=grid_line_width, color=color_minor + ) + if store_configs == 2: - json_dict['grid_line_color_major'] = [round(x*255., 2) for x in color_major] - json_dict['grid_line_color_minor'] = [round(x*255., 2) for x in color_minor] - json_dict['ecg_plot_color'] = [round(x*255., 2) for x in color_line] + json_dict["grid_line_color_major"] = [ + round(x * 255.0, 2) for x in color_major + ] + json_dict["grid_line_color_minor"] = [ + round(x * 255.0, 2) for x in color_minor + ] + json_dict["ecg_plot_color"] = [round(x * 255.0, 2) for x in color_line] else: ax.grid(False) - plt.savefig(os.path.join(output_dir,tail +'.png'),dpi=resolution) + fig.savefig(os.path.join(output_dir, tail + ".png"), dpi=resolution) plt.close(fig) - plt.clf() - plt.cla() + # plt.clf() + # plt.cla() + + if pad_inches != 0: + ecg_image = Image.open(os.path.join(output_dir, tail + ".png")) - if pad_inches!=0: - - ecg_image = Image.open(os.path.join(output_dir,tail +'.png')) - right = pad_inches * resolution left = pad_inches * resolution top = pad_inches * resolution @@ -521,17 +694,69 @@ def ecg_plot( width, height = ecg_image.size new_width = width + right + left new_height = height + top + bottom - result_image = Image.new(ecg_image.mode, (new_width, new_height), (255, 255, 255)) + result_image = Image.new( + ecg_image.mode, (new_width, new_height), (255, 255, 255) + ) result_image.paste(ecg_image, (left, top)) - - result_image.save(os.path.join(output_dir,tail +'.png')) - plt.close('all') + result_image.save(os.path.join(output_dir, tail + ".png")) + + plt.close("all") plt.close(fig) plt.clf() plt.cla() + + ax_grid.set_xticks(np.arange(x_min, x_max, x_grid_size)) + ax_grid.set_yticks(np.arange(y_min, y_max, y_grid_size)) + ax_grid.grid(which="major", linestyle="-", linewidth=grid_line_width, color=color_major) + + # ax.grid(which="minor", linestyle="-", linewidth=grid_line_width, color=color_minor) + buf = BytesIO() + fig_grid.savefig( + buf, + dpi=resolution, + ) + plt.close(fig) + buf.seek(0) + + img = Image.open(buf).convert("RGB") + grid_image_matrix = np.array(img).mean(axis=2).astype(np.uint8) + # cv2.imwrite( + # os.path.join(output_dir, tail + "_grid.png"), + # image_matrix.mean(axis=2).astype(np.uint8), + # ) + # fig_text.savefig(os.path.join(output_dir, tail + "_text.png")) + + + buf = BytesIO() + fig_leads.savefig(buf, dpi=resolution) + plt.close(fig_leads) + buf.seek(0) + img = Image.open(buf).convert("RGB") + leads_image_matrix = np.array(img).mean(axis=2).astype(np.uint8) + # cv2.imwrite( + # os.path.join(output_dir, tail + "_leads.png"), + # image_matrix.mean(axis=2).astype(np.uint8), + # ) + + + + buf = BytesIO() + fig_text.savefig(buf, dpi=resolution) + plt.close(fig_text) + buf.seek(0) + img = Image.open(buf).convert("RGB") + text_image_matrix = np.array(img).mean(axis=2).astype(np.uint8) + # cv2.imwrite( + # os.path.join(output_dir, tail + "_text.png"), + # image_matrix.mean(axis=2).astype(np.uint8), + # ) + mask = np.stack((grid_image_matrix, text_image_matrix, leads_image_matrix), axis=2) + cv2.imwrite( + os.path.join(output_dir, tail + "_mask.png"), + 255 - mask, + ) json_dict["leads"] = leads_ds - return x_grid_dots,y_grid_dots - \ No newline at end of file + return x_grid_dots, y_grid_dots diff --git a/codes/ecg-image-generator/environment_droplet.yml b/codes/ecg-image-generator/environment.yml similarity index 82% rename from codes/ecg-image-generator/environment_droplet.yml rename to codes/ecg-image-generator/environment.yml index b61ac4e..7dc2406 100644 --- a/codes/ecg-image-generator/environment_droplet.yml +++ b/codes/ecg-image-generator/environment.yml @@ -1,4 +1,4 @@ -name: myenv +name: ecg-image-generator channels: - defaults dependencies: @@ -24,4 +24,5 @@ dependencies: - seaborn==0.12.2 - validators==0.18.2 - spacy==3.0.8 - - https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_sm-0.4.0.tar.gz + - pyyaml==6.0 + - qrcode==8.2 diff --git a/codes/ecg-image-generator/extract_leads.py b/codes/ecg-image-generator/extract_leads.py index 2bd7274..9cf1547 100644 --- a/codes/ecg-image-generator/extract_leads.py +++ b/codes/ecg-image-generator/extract_leads.py @@ -7,17 +7,54 @@ import matplotlib.pyplot as plt from matplotlib.ticker import AutoMinorLocator from TemplateFiles.generate_template import generate_template -from math import ceil -from helper_functions import get_adc_gains,get_frequency,get_leads,load_recording,load_header,find_files, truncate_signal, create_signal_dictionary, standardize_leads, write_wfdb_file +from math import ceil +from helper_functions import ( + get_adc_gains, + get_frequency, + get_leads, + load_recording, + load_header, + find_files, + truncate_signal, + create_signal_dictionary, + standardize_leads, + write_wfdb_file, +) from ecg_plot import ecg_plot import wfdb from PIL import Image, ImageDraw, ImageFont from random import randint import random -# Run script. -def get_paper_ecg(input_file,header_file,output_directory, seed, add_dc_pulse,add_bw,show_grid, add_print, configs, mask_unplotted_samples = False, start_index = -1, store_configs=False, store_text_bbox=True,key='val',resolution=100,units='inches',papersize='',add_lead_names=True,pad_inches=1,template_file=os.path.join('TemplateFiles','TextFile1.txt'),font_type=os.path.join('Fonts','Times_New_Roman.ttf'),standard_colours=5,full_mode='II',bbox = False,columns=-1): +# Run script. +def get_paper_ecg( + input_file, + header_file, + output_directory, + seed, + add_dc_pulse, + add_bw, + show_grid, + add_print, + configs, + mask_unplotted_samples=False, + start_index=-1, + store_configs=False, + store_text_bbox=True, + key="val", + resolution=100, + units="inches", + papersize="", + add_lead_names=True, + pad_inches=1, + template_file=os.path.join("TemplateFiles", "TextFile1.txt"), + font_type=os.path.join("Fonts", "Times_New_Roman.ttf"), + standard_colours=5, + full_mode="II", + bbox=False, + columns=-1, +): # Extract a reduced-lead set from each pair of full-lead header and recording files. full_header_file = header_file full_recording_file = input_file @@ -26,7 +63,7 @@ def get_paper_ecg(input_file,header_file,output_directory, seed, add_dc_pulse,ad num_full_leads = len(full_leads) # Update the header file - full_lines = full_header.split('\n') + full_lines = full_header.split("\n") # For the first line, update the number of leads. entries = full_lines[0].split() @@ -34,140 +71,174 @@ def get_paper_ecg(input_file,header_file,output_directory, seed, add_dc_pulse,ad head, tail = os.path.split(full_header_file) output_header_file = os.path.join(output_directory, tail) - with open(output_header_file, 'w') as f: - f.write('\n'.join(full_lines)) + with open(output_header_file, "w") as f: + f.write("\n".join(full_lines)) - #Load the full-lead recording file, extract the lead data, and save the reduced-lead recording file. - recording = load_recording(full_recording_file, full_header,key) + # Load the full-lead recording file, extract the lead data, and save the reduced-lead recording file. + recording = load_recording(full_recording_file, full_header, key) # Get values from header rate = get_frequency(full_header) - adc = get_adc_gains(full_header,full_leads) - + adc = get_adc_gains(full_header, full_leads) + full_leads = standardize_leads(full_leads) - if(len(full_leads)==2): - full_mode = 'None' + if len(full_leads) == 2: + full_mode = "None" gen_m = 2 - if(columns==-1): + if columns == -1: columns = 1 - - elif(len(full_leads)==12): + + elif len(full_leads) == 12: gen_m = 12 if full_mode not in full_leads: full_mode = full_leads[0] else: full_mode = full_mode - if(columns==-1): + if columns == -1: columns = 4 else: gen_m = len(full_leads) columns = 4 - full_mode = 'None' + full_mode = "None" - template_name = 'custom_template.png' + template_name = "custom_template.png" - if(recording.shape[0] != num_full_leads): + if recording.shape[0] != num_full_leads: recording = np.transpose(recording) - record_dict = create_signal_dictionary(recording,full_leads) - + record_dict = create_signal_dictionary(recording, full_leads) + gain_index = 0 ecg_frame = [] end_flag = False start = 0 - lead_length_in_seconds = configs['paper_len']/columns - abs_lead_step = configs['abs_lead_step'] - format_4_by_3 = configs['format_4_by_3'] - + lead_length_in_seconds = configs["paper_len"] / columns + abs_lead_step = configs["abs_lead_step"] + format_4_by_3 = configs["format_4_by_3"] + segmented_ecg_data = {} if start_index != -1: start = start_index - #do something + # do something frame = {} gain_index = 0 for key in record_dict: - if(len(record_dict[key][start:])int(rate*10)): - frame['full'+full_mode] = record_dict[key][start:(start+int(rate)*10)] - frame['full'+full_mode] = frame['full'+full_mode] - if 'full'+full_mode not in segmented_ecg_data.keys(): - segmented_ecg_data['full'+full_mode] = frame['full'+full_mode].tolist() + nanArray[:] = record_dict[key][end : end + nanArray_len] + segmented_ecg_data[key] = ( + segmented_ecg_data[key] + nanArray.tolist() + ) + if full_mode != "None" and key == full_mode: + if len(record_dict[key][start:]) > int(rate * 10): + frame["full" + full_mode] = record_dict[key][ + start : (start + int(rate) * 10) + ] + frame["full" + full_mode] = frame["full" + full_mode] + if "full" + full_mode not in segmented_ecg_data.keys(): + segmented_ecg_data["full" + full_mode] = frame[ + "full" + full_mode + ].tolist() else: - segmented_ecg_data['full'+full_mode] = segmented_ecg_data['full'+full_mode] + frame['full'+full_mode].tolist() + segmented_ecg_data["full" + full_mode] = ( + segmented_ecg_data["full" + full_mode] + + frame["full" + full_mode].tolist() + ) else: - frame['full'+full_mode] = record_dict[key][start:] - frame['full'+full_mode] = frame['full'+full_mode] - if 'full'+full_mode not in segmented_ecg_data.keys(): - segmented_ecg_data['full'+full_mode] = frame['full'+full_mode].tolist() + frame["full" + full_mode] = record_dict[key][start:] + frame["full" + full_mode] = frame["full" + full_mode] + if "full" + full_mode not in segmented_ecg_data.keys(): + segmented_ecg_data["full" + full_mode] = frame[ + "full" + full_mode + ].tolist() else: - segmented_ecg_data['full'+full_mode] = segmented_ecg_data['full'+full_mode] + frame['full'+full_mode].tolist() + segmented_ecg_data["full" + full_mode] = ( + segmented_ecg_data["full" + full_mode] + + frame["full" + full_mode].tolist() + ) gain_index += 1 ecg_frame.append(frame) else: - while(end_flag==False): + while end_flag == False: # To do : Incorporate column and ful_mode info frame = {} gain_index = 0 - + for key in record_dict: - if(len(record_dict[key][start:])int(rate*10)): - frame['full'+full_mode] = record_dict[key][start:(start+int(rate)*10)] - frame['full'+full_mode] = frame['full'+full_mode] - if 'full'+full_mode not in segmented_ecg_data.keys(): - segmented_ecg_data['full'+full_mode] = frame['full'+full_mode].tolist() + nanArray[:] = record_dict[key][end : end + nanArray_len] + + segmented_ecg_data[key] = ( + segmented_ecg_data[key] + nanArray.tolist() + ) + if full_mode != "None" and key == full_mode: + if len(record_dict[key][start:]) > int(rate * 10): + frame["full" + full_mode] = record_dict[key][ + start : (start + int(rate) * 10) + ] + frame["full" + full_mode] = frame["full" + full_mode] + if "full" + full_mode not in segmented_ecg_data.keys(): + segmented_ecg_data["full" + full_mode] = frame[ + "full" + full_mode + ].tolist() else: - segmented_ecg_data['full'+full_mode] = segmented_ecg_data['full'+full_mode] + frame['full'+full_mode].tolist() + segmented_ecg_data["full" + full_mode] = ( + segmented_ecg_data["full" + full_mode] + + frame["full" + full_mode].tolist() + ) else: - frame['full'+full_mode] = record_dict[key][start:] - frame['full'+full_mode] = frame['full'+full_mode] - if 'full'+full_mode not in segmented_ecg_data.keys(): - segmented_ecg_data['full'+full_mode] = frame['full'+full_mode].tolist() + frame["full" + full_mode] = record_dict[key][start:] + frame["full" + full_mode] = frame["full" + full_mode] + if "full" + full_mode not in segmented_ecg_data.keys(): + segmented_ecg_data["full" + full_mode] = frame[ + "full" + full_mode + ].tolist() else: - segmented_ecg_data['full'+full_mode] = segmented_ecg_data['full'+full_mode] + frame['full'+full_mode].tolist() + segmented_ecg_data["full" + full_mode] = ( + segmented_ecg_data["full" + full_mode] + + frame["full" + full_mode].tolist() + ) gain_index += 1 - if(end_flag==False): + if end_flag == False: ecg_frame.append(frame) - start = start + int(rate*abs_lead_step) + start = start + int(rate * abs_lead_step) outfile_array = [] - + name, ext = os.path.splitext(full_header_file) - write_wfdb_file(segmented_ecg_data, name, rate, header_file, output_directory, full_mode, mask_unplotted_samples) + write_wfdb_file( + segmented_ecg_data, + name, + rate, + header_file, + output_directory, + full_mode, + mask_unplotted_samples, + ) if len(ecg_frame) == 0: return outfile_array @@ -258,22 +372,47 @@ def get_paper_ecg(input_file,header_file,output_directory, seed, add_dc_pulse,ad print_txt = add_print.rvs() json_dict = {} - json_dict['sampling_frequency'] = rate - grid_colour = 'colour' - if(bw): - grid_colour = 'bw' + json_dict["sampling_frequency"] = rate + grid_colour = "colour" + if bw: + grid_colour = "bw" - rec_file = name + '-' + str(i) + rec_file = name + "-" + str(i) if ecg_frame[i] == {}: continue - x_grid,y_grid = ecg_plot(ecg_frame[i], configs=configs, full_header_file=full_header_file, style=grid_colour, sample_rate = rate,columns=columns,rec_file_name = rec_file, output_dir = output_directory, resolution = resolution, pad_inches = pad_inches, lead_index=full_leads, full_mode = full_mode, store_text_bbox = store_text_bbox, show_lead_name=add_lead_names,show_dc_pulse=dc,papersize=papersize,show_grid=(grid),standard_colours=standard_colours,bbox=bbox, print_txt=print_txt, json_dict=json_dict, start_index=start, store_configs=store_configs, lead_length_in_seconds=lead_length_in_seconds) + x_grid, y_grid = ecg_plot( + ecg_frame[i], + configs=configs, + full_header_file=full_header_file, + style=grid_colour, + sample_rate=rate, + columns=columns, + rec_file_name=rec_file, + output_dir=output_directory, + resolution=resolution, + pad_inches=pad_inches, + lead_index=full_leads, + full_mode=full_mode, + store_text_bbox=store_text_bbox, + show_lead_name=add_lead_names, + show_dc_pulse=dc, + papersize=papersize, + show_grid=(grid), + standard_colours=standard_colours, + bbox=bbox, + print_txt=print_txt, + json_dict=json_dict, + start_index=start, + store_configs=store_configs, + lead_length_in_seconds=lead_length_in_seconds, + ) rec_head, rec_tail = os.path.split(rec_file) - + json_dict["x_grid"] = round(x_grid, 3) json_dict["y_grid"] = round(y_grid, 3) - json_dict["resolution"] =resolution + json_dict["resolution"] = resolution json_dict["pad_inches"] = pad_inches if store_configs == 2: @@ -282,17 +421,17 @@ def get_paper_ecg(input_file,header_file,output_directory, seed, add_dc_pulse,ad json_dict["gridlines"] = bool(grid) json_dict["printed_text"] = bool(print_txt) json_dict["number_of_columns_in_image"] = columns - json_dict["full_mode_lead"] =full_mode + json_dict["full_mode_lead"] = full_mode - outfile = os.path.join(output_directory,rec_tail+'.png') + outfile = os.path.join(output_directory, rec_tail + ".png") json_object = json.dumps(json_dict, indent=4) # Writing to sample.json if store_configs: - with open(os.path.join(output_directory,rec_tail+'.json'), "w") as f: + with open(os.path.join(output_directory, rec_tail + ".json"), "w") as f: f.write(json_object) outfile_array.append(outfile) - start += int(rate*abs_lead_step) - return outfile_array \ No newline at end of file + start += int(rate * abs_lead_step) + return outfile_array diff --git a/codes/ecg-image-generator/gen_ecg_images_from_data_batch.py b/codes/ecg-image-generator/gen_ecg_images_from_data_batch.py index df0efef..a6b45fc 100644 --- a/codes/ecg-image-generator/gen_ecg_images_from_data_batch.py +++ b/codes/ecg-image-generator/gen_ecg_images_from_data_batch.py @@ -1,114 +1,153 @@ -import os, sys, argparse +from copy import deepcopy +import multiprocessing +import os +import sys +import argparse import random import csv from helper_functions import find_records from gen_ecg_image_from_data import run_single_file import warnings +from tqdm import tqdm -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" warnings.filterwarnings("ignore") + def get_parser(): parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input_directory', type=str, required=True) - parser.add_argument('-o', '--output_directory', type=str, required=True) - parser.add_argument('-se', '--seed', type=int, required=False, default = -1) - parser.add_argument('--num_leads',type=str,default='twelve') - parser.add_argument('--max_num_images',type=int,default = -1) - parser.add_argument('--config_file', type=str, default='config.yaml') - - parser.add_argument('-r','--resolution',type=int,required=False,default = 200) - parser.add_argument('--pad_inches',type=int,required=False,default=0) - parser.add_argument('-ph','--print_header', action="store_true",default=False) - parser.add_argument('--num_columns',type=int,default = -1) - parser.add_argument('--full_mode', type=str,default='II') - parser.add_argument('--mask_unplotted_samples', action="store_true", default=False) - parser.add_argument('--add_qr_code', action="store_true", default=False) - - parser.add_argument('-l', '--link', type=str, required=False,default='') - parser.add_argument('-n','--num_words',type=int,required=False,default=5) - parser.add_argument('--x_offset',dest='x_offset',type=int,default = 30) - parser.add_argument('--y_offset',dest='y_offset',type=int,default = 30) - parser.add_argument('--hws',dest='handwriting_size_factor',type=float,default = 0.2) - - parser.add_argument('-ca','--crease_angle',type=int,default=90) - parser.add_argument('-nv','--num_creases_vertically',type=int,default=10) - parser.add_argument('-nh','--num_creases_horizontally',type=int,default=10) - - parser.add_argument('-rot','--rotate',type=int,default=0) - parser.add_argument('-noise','--noise',type=int,default=50) - parser.add_argument('-c','--crop',type=float,default=0.01) - parser.add_argument('-t','--temperature',type=int,default=40000) - - parser.add_argument('--random_resolution',action="store_true",default=False) - parser.add_argument('--random_padding',action="store_true",default=False) - parser.add_argument('--random_grid_color',action="store_true",default=False) - parser.add_argument('--standard_grid_color', type=int, default=5) - parser.add_argument('--calibration_pulse',type=float,default=1) - parser.add_argument('--random_grid_present',type=float,default=1) - parser.add_argument('--random_print_header',type=float,default=0) - parser.add_argument('--random_bw',type=float,default=0) - parser.add_argument('--remove_lead_names',action="store_false",default=True) - parser.add_argument('--lead_name_bbox',action="store_true",default=False) - parser.add_argument('--store_config', type=int, nargs='?', const=1, default=0) - - parser.add_argument('--deterministic_offset',action="store_true",default=False) - parser.add_argument('--deterministic_num_words',action="store_true",default=False) - parser.add_argument('--deterministic_hw_size',action="store_true",default=False) - - parser.add_argument('--deterministic_angle',action="store_true",default=False) - parser.add_argument('--deterministic_vertical',action="store_true",default=False) - parser.add_argument('--deterministic_horizontal',action="store_true",default=False) - - parser.add_argument('--deterministic_rot',action="store_true",default=False) - parser.add_argument('--deterministic_noise',action="store_true",default=False) - parser.add_argument('--deterministic_crop',action="store_true",default=False) - parser.add_argument('--deterministic_temp',action="store_true",default=False) - - parser.add_argument('--fully_random',action='store_true',default=False) - parser.add_argument('--hw_text',action='store_true',default=False) - parser.add_argument('--wrinkles',action='store_true',default=False) - parser.add_argument('--augment',action='store_true',default=False) - parser.add_argument('--lead_bbox',action='store_true',default=False) + parser.add_argument("-i", "--input_directory", type=str, required=True) + parser.add_argument("-o", "--output_directory", type=str, required=True) + parser.add_argument("-se", "--seed", type=int, required=False, default=-1) + parser.add_argument("--num_leads", type=str, default="twelve") + parser.add_argument("--max_num_images", type=int, default=-1) + parser.add_argument("--config_file", type=str, default="config.yaml") + + parser.add_argument("-r", "--resolution", type=int, required=False, default=200) + parser.add_argument("--pad_inches", type=int, required=False, default=0) + parser.add_argument("-ph", "--print_header", action="store_true", default=False) + parser.add_argument("--num_columns", type=int, default=-1) + parser.add_argument("--full_mode", type=str, default="II") + parser.add_argument("--mask_unplotted_samples", action="store_true", default=False) + parser.add_argument("--add_qr_code", action="store_true", default=False) + + parser.add_argument("-l", "--link", type=str, required=False, default="") + parser.add_argument("-n", "--num_words", type=int, required=False, default=5) + parser.add_argument("--x_offset", dest="x_offset", type=int, default=30) + parser.add_argument("--y_offset", dest="y_offset", type=int, default=30) + parser.add_argument( + "--hws", dest="handwriting_size_factor", type=float, default=0.2 + ) + + parser.add_argument("-ca", "--crease_angle", type=int, default=90) + parser.add_argument("-nv", "--num_creases_vertically", type=int, default=10) + parser.add_argument("-nh", "--num_creases_horizontally", type=int, default=10) + + parser.add_argument("-rot", "--rotate", type=int, default=0) + parser.add_argument("-noise", "--noise", type=int, default=50) + parser.add_argument("-c", "--crop", type=float, default=0.01) + parser.add_argument("-t", "--temperature", type=int, default=40000) + + parser.add_argument("--random_resolution", action="store_true", default=False) + parser.add_argument("--random_padding", action="store_true", default=False) + parser.add_argument("--random_grid_color", action="store_true", default=False) + parser.add_argument("--standard_grid_color", type=int, default=5) + parser.add_argument("--calibration_pulse", type=float, default=1) + parser.add_argument("--random_grid_present", type=float, default=1) + parser.add_argument("--random_print_header", type=float, default=0) + parser.add_argument("--random_bw", type=float, default=0) + parser.add_argument("--remove_lead_names", action="store_false", default=True) + parser.add_argument("--lead_name_bbox", action="store_true", default=False) + parser.add_argument("--store_config", type=int, nargs="?", const=1, default=0) + + parser.add_argument("--deterministic_offset", action="store_true", default=False) + parser.add_argument("--deterministic_num_words", action="store_true", default=False) + parser.add_argument("--deterministic_hw_size", action="store_true", default=False) + + parser.add_argument("--deterministic_angle", action="store_true", default=False) + parser.add_argument("--deterministic_vertical", action="store_true", default=False) + parser.add_argument( + "--deterministic_horizontal", action="store_true", default=False + ) + + parser.add_argument("--deterministic_rot", action="store_true", default=False) + parser.add_argument("--deterministic_noise", action="store_true", default=False) + parser.add_argument("--deterministic_crop", action="store_true", default=False) + parser.add_argument("--deterministic_temp", action="store_true", default=False) + + parser.add_argument("--fully_random", action="store_true", default=False) + parser.add_argument("--hw_text", action="store_true", default=False) + parser.add_argument("--wrinkles", action="store_true", default=False) + parser.add_argument("--augment", action="store_true", default=False) + parser.add_argument("--lead_bbox", action="store_true", default=False) + + parser.add_argument("--num_cpus", type=int, default=1) return parser + +def run_on_one_file_parallel(func_args: tuple): + args, full_recording_file, full_header_file, original_output_dir = func_args + filename = full_recording_file + header = full_header_file + args.input_file = os.path.join(args.input_directory, filename) + args.header_file = os.path.join(args.input_directory, header) + args.start_index = -1 + + folder_struct_list = full_header_file.split("/")[:-1] + args.output_directory = os.path.join( + original_output_dir, "/".join(folder_struct_list) + ) + args.encoding = os.path.split(os.path.splitext(filename)[0])[1] + run_single_file(args) + + def run(args): - random.seed(args.seed) - - if os.path.isabs(args.input_directory) == False: - args.input_directory = os.path.normpath(os.path.join(os.getcwd(), args.input_directory)) - if os.path.isabs(args.output_directory) == False: - original_output_dir = os.path.normpath(os.path.join(os.getcwd(), args.output_directory)) - else: - original_output_dir = args.output_directory - - if os.path.exists(args.input_directory) == False or os.path.isdir(args.input_directory) == False: - raise Exception("The input directory does not exist, Please re-check the input arguments!") - - if os.path.exists(original_output_dir) == False: - os.makedirs(original_output_dir) - - i = 0 - full_header_files, full_recording_files = find_records(args.input_directory, original_output_dir) - - for full_header_file, full_recording_file in zip(full_header_files, full_recording_files): - filename = full_recording_file - header = full_header_file - args.input_file = os.path.join(args.input_directory, filename) - args.header_file = os.path.join(args.input_directory, header) - args.start_index = -1 - - folder_struct_list = full_header_file.split('/')[:-1] - args.output_directory = os.path.join(original_output_dir, '/'.join(folder_struct_list)) - args.encoding = os.path.split(os.path.splitext(filename)[0])[1] - - i += run_single_file(args) - - if(args.max_num_images != -1 and i >= args.max_num_images): - break - -if __name__=='__main__': + random.seed(args.seed) + + if os.path.isabs(args.input_directory) == False: + args.input_directory = os.path.normpath( + os.path.join(os.getcwd(), args.input_directory) + ) + if os.path.isabs(args.output_directory) == False: + original_output_dir = os.path.normpath( + os.path.join(os.getcwd(), args.output_directory) + ) + else: + original_output_dir = args.output_directory + + if ( + os.path.exists(args.input_directory) == False + or os.path.isdir(args.input_directory) == False + ): + raise Exception( + "The input directory does not exist, Please re-check the input arguments!" + ) + + if os.path.exists(original_output_dir) == False: + os.makedirs(original_output_dir) + + i = 0 + full_header_files, full_recording_files = find_records( + args.input_directory, original_output_dir + ) + func_args = [] + for full_header_file, full_recording_file in zip( + full_header_files, full_recording_files + ): + func_args.append( + (deepcopy(args), full_recording_file, full_header_file, original_output_dir) + ) + + if args.num_cpus > 1: + with multiprocessing.Pool(processes=args.num_cpus) as pool: + pool.map(run_on_one_file_parallel, func_args) + else: + for func_arg in tqdm(func_args): + run_on_one_file_parallel(func_arg) + + +if __name__ == "__main__": path = os.path.join(os.getcwd(), sys.argv[0]) parentPath = os.path.dirname(path) os.chdir(parentPath)