-
Notifications
You must be signed in to change notification settings - Fork 0
Supplemental figure - linear regression model #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 59 commits
Commits
Show all changes
74 commits
Select commit
Hold shift + click to select a range
9dfa487
add mean and std over growth phase
chantelleleveille 7680ba1
add features for LRM
chantelleleveille e92a4aa
update added features
chantelleleveille 99a9f26
update documentation
chantelleleveille 7ace64f
update features, simplify
chantelleleveille 94d7cf6
Add transient gr of colony at tp, mean, std
chantelleleveille 845b3fe
update features and scale
chantelleleveille 097546f
update label tables
chantelleleveille 76bafd3
Merge remote-tracking branch 'origin/dev' into features_LRM
chantelleleveille d6fd7ab
Use .values instead of .unique()
chantelleleveille 567d19a
linear model changes + label table updates
bfff7c2
use label tables to get scale
chantelleleveille 36a0d63
remove stray pdb + make movie
09dd626
fix import, add function
chantelleleveille 5a71cf6
get_early_transient_gr_of_whole_colony
chantelleleveille a13c742
Merge branch 'features_LRM' of https://github.com/AllenCell/nuc-morph…
chantelleleveille 2beea44
add lineage features for lrm
chantelleleveille df9e859
add sisters + reorient plots
46dfe28
Merge remote-tracking branch 'origin/dev' into features_LRM
chantelleleveille 140641c
update figure formatting + label tables
chantelleleveille 3ae21bd
update_plot_axis
chantelleleveille 8cb3d39
update label tables
chantelleleveille 2202666
add extrinsic neighborhood feats
chantelleleveille 656b292
add neighborhood feats for lrm
chantelleleveille 8e008e2
worklfow only loads data once
chantelleleveille 1cff138
Update features and names
chantelleleveille f0967eb
add feature correlation plot
chantelleleveille e7b4a6c
update documentation
chantelleleveille e88accf
Normalize sum of mitotic and death events
chantelleleveille 5fa0462
fix string
chantelleleveille 3d870fb
add delta volume as a lineage feature
chantelleleveille 37d4181
update label tables
chantelleleveille 830aedd
add greedy removal workflow
6bf4342
add scripts for greedy removal
c0ec8e1
move some fns to utils
c1fd590
Merge branch 'features_LRM' of https://github.com/AllenCell/nuc-morph…
747bd4f
Merge branch 'features_LRM' of https://github.com/AllenCell/nuc-morph…
chantelleleveille 4619e6d
update workflow to save fit_linear_regression
chantelleleveille 6868f2a
update features into catagories
chantelleleveille fda3a68
update workflow to make r squared matrix
chantelleleveille e55872a
update regression so that you can start at 0
chantelleleveille 4bf140e
in progress
chantelleleveille a262235
Merge remote-tracking branch 'origin/dev' into features_LRM
chantelleleveille fb2ada0
update to new density
chantelleleveille fa2fdb3
Update workflow to have maximum r squared for feature groups
chantelleleveille c247f19
update to use config all feats
chantelleleveille d288135
confirm N is the same for subset cols
chantelleleveille 61f8f21
update workflow
chantelleleveille 016722b
Update main function to avoid error
chantelleleveille f82e19a
change to linear color map for heatmapt
chantelleleveille 13f5250
add clustering to correlation heatmap
chantelleleveille a8be049
add annotations to cluster plot
chantelleleveille 4abb45b
update colormaps
chantelleleveille bea98f1
update workflow
chantelleleveille c840327
organize and document
chantelleleveille 1e19750
remove unused code
chantelleleveille 056ec6c
add figure to run all workflows
chantelleleveille ebc1097
remove unused code
chantelleleveille 9c48be7
update documentation
chantelleleveille 0c6fb0f
PR reveiw changes to add_features
chantelleleveille 2bff19f
PR review changes to global dataset filtering
chantelleleveille 68da6df
return df
chantelleleveille aa1f1cc
update maximum alpha in for ft importance figure
chantelleleveille ed24278
update label tables to sent. case and neighborhood fts
chantelleleveille c8ce61c
save w/ layout tight so pdf plots dont get cut off
chantelleleveille ac26b74
remove hard coded max alpha, default pre-compute
chantelleleveille d816f2d
update imports
chantelleleveille fc4c310
one off change to cell health workflow to print a reported metric!
chantelleleveille 785e495
default to load all features if none specified
chantelleleveille 2982d3c
PR comments - remove rounding
chantelleleveille d237b48
update feature name
chantelleleveille 6558f3c
PR comments - documentation, remove dropna step
chantelleleveille 5440777
PR comment - remove rounding alpha
chantelleleveille 964d45b
update correlation plot
chantelleleveille File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
265 changes: 265 additions & 0 deletions
265
nuc_morph_analysis/analyses/linear_regression/analysis_plots.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,265 @@ | ||
| import seaborn as sns | ||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
| import pandas as pd | ||
| from nuc_morph_analysis.lib.visualization.plotting_tools import get_plot_labels_for_metric | ||
| from nuc_morph_analysis.lib.visualization.notebook_tools import save_and_show_plot | ||
| from nuc_morph_analysis.analyses.linear_regression.linear_regression import fit_linear_regression | ||
| from nuc_morph_analysis.analyses.linear_regression.select_features import (get_feature_list) | ||
|
|
||
| def plot_feature_cluster_correlations(df_track_level_features, feature_list, figdir): | ||
| """ | ||
| Plot clustermap of feature correlations. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| df_track_level_features : pd.DataFrame | ||
| DataFrame containing track level features | ||
| feature_list : list | ||
| List of features to include in the clustermap | ||
| Output from get_feature_list | ||
| figdir : str | ||
| Directory to save the figure | ||
|
|
||
| Returns | ||
| ------- | ||
| Figure | ||
| """ | ||
| data = df_track_level_features[feature_list] | ||
|
|
||
| cluster_grid = sns.clustermap(data.corr(), vmin=-1, vmax=1, cmap='BrBG', #'vlag' red to blue | ||
| cbar_pos=(0.4, 0.9, 0.3, 0.02), | ||
| cbar_kws={"orientation": "horizontal"}, | ||
| annot=True, fmt=".1f", annot_kws={"size": 12}, | ||
| figsize=(18, 18)) | ||
|
|
||
|
|
||
| # Hide the dendrograms | ||
| cluster_grid.ax_row_dendrogram.set_visible(False) | ||
| cluster_grid.ax_col_dendrogram.set_visible(False) | ||
|
|
||
| # Get the reordered labels using the dendrogram information | ||
| reordered_column_index = cluster_grid.dendrogram_col.reordered_ind | ||
| reordered_labels = [get_plot_labels_for_metric(data.columns[i])[1] for i in reordered_column_index] | ||
|
|
||
| # add a number to the end of each label | ||
| reordered_labels = [f'{label} ({i+1})' for i, label in enumerate(reordered_labels)] | ||
|
|
||
| # for the bottom label just use numbers | ||
| numbered_labels = [f'{i+1}' for i in range(len(reordered_labels))] | ||
|
|
||
| # Ensure the number of labels matches the number of ticks | ||
| cluster_grid.ax_heatmap.set_xticks([x + 0.5 for x in range(len(numbered_labels))]) | ||
| cluster_grid.ax_heatmap.set_xticklabels(numbered_labels, rotation=0) | ||
| cluster_grid.ax_heatmap.set_yticks([y + 0.5 for y in range(len(reordered_labels))]) | ||
| cluster_grid.ax_heatmap.set_yticklabels(reordered_labels, rotation=0) | ||
|
|
||
| # Adjust the padding between the labels and the heatmap | ||
| cluster_grid.ax_heatmap.tick_params(axis='x', labelsize=12, width=0.7) | ||
| cluster_grid.ax_heatmap.tick_params(axis='y', labelsize=12, width=0.7, labelright=False, labelleft=True, left=True, right=False) | ||
|
|
||
| save_and_show_plot(f'{figdir}/feature_correlation_clustermap', figure=cluster_grid.fig, dpi=300) | ||
|
|
||
|
|
||
| def run_regression(df_track_level_features, target, features, name, alpha, figdir): | ||
| """ | ||
| Run linear regression on the given dataset and return the results. | ||
|
|
||
| Parameters: | ||
| ---------- | ||
| df_track_level_features (pd.DataFrame): DataFrame containing the track level features. | ||
| target (str): The target variable for regression. | ||
| features (list): List of features to be used for regression. | ||
| name (str): Name of the feature group. | ||
| alpha (list): List of alpha values for regularization. | ||
| figdir (str): Directory path to save the figures. | ||
|
|
||
| Returns: | ||
| -------- | ||
| dict: A dictionary containing the target, feature group name, mean R-squared value, | ||
| standard deviation of R-squared values, alpha value, and the features used. | ||
| """ | ||
| _, all_test_sc, _ = fit_linear_regression( | ||
| df_track_level_features, | ||
| cols=get_feature_list(features, target), | ||
| target=target, | ||
| alpha=alpha, | ||
| tol=0.04, | ||
| save_path=figdir, | ||
| save=False, | ||
| multiple_predictions=False | ||
| ) | ||
| print(f"Target {target}, Alpha: {alpha}. Feature group: {name}") | ||
| r_squared = round(all_test_sc["Test r$^2$"].mean(), 3) | ||
| std = round(all_test_sc["Test r$^2$"].std(), 3) | ||
| return {'target': target, 'feature_group': name, 'r_squared': r_squared, 'stdev': std, 'alpha': 0, 'feats_used': get_feature_list(features, target)} | ||
|
|
||
| def run_regression_workflow(targets, feature_configs, df_track_level_features, figdir, alpha): | ||
| """ | ||
| Run the regression workflow for multiple targets and feature configurations. | ||
|
|
||
| Parameters: | ||
| ---------- | ||
| targets (list): List of target variables for regression. | ||
| feature_configs (dict): Dictionary where keys are feature group names and values are lists of features. | ||
| df_track_level_features (pd.DataFrame): DataFrame containing the track level features. | ||
| figdir (str): Directory path to save the figures and results. | ||
| alpha (float): Alpha value for regularization. | ||
|
|
||
| Returns: | ||
| -------- | ||
| pd.DataFrame: DataFrame containing the results of the regression workflow, including target, | ||
| R-squared value, standard deviation, feature group, alpha value, and features used. | ||
| """ | ||
| df = pd.DataFrame(columns=['target', 'r_squared', 'stdev', 'feature_group', 'alpha', 'feats_used']) | ||
|
|
||
| for target in targets: | ||
| for name, features in feature_configs.items(): | ||
| result = run_regression(df_track_level_features, target, features, name, [alpha], figdir) | ||
| df = df.append(result, ignore_index=True) | ||
| df.to_csv(f"{figdir}r_squared_results.csv") | ||
|
|
||
| df['num_feats_used'] = df['feats_used'].apply(lambda x: len(x)) | ||
|
|
||
| return df | ||
|
|
||
|
|
||
| def plot_heatmap(df, figdir, cmap='coolwarm'): | ||
| """ | ||
| Plot heatmap of r_squared values for different feature groups. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| df : pd.DataFrame | ||
| DataFrame containing r_squared values for different feature groups | ||
| figdir : str | ||
| Directory to save the figure | ||
| cmap: str | ||
| linear colormap | ||
|
|
||
| Returns | ||
| ------- | ||
| Figure | ||
| """ | ||
| # Split the 'feature_group' column into two | ||
| df[['start_lifetime', 'intrinsic_extrinsic']] = df['feature_group'].str.split('_', expand=True) | ||
| def replace_values(val): | ||
| if val in ['all', 'features']: | ||
| return 'all_features' | ||
| else: | ||
| return val | ||
| df[['start_lifetime', 'intrinsic_extrinsic']] = df[['start_lifetime', 'intrinsic_extrinsic']].applymap(replace_values) | ||
| for index, row in df.iterrows(): | ||
| if row['start_lifetime'] == 'intrinsic': | ||
| df.at[index, 'intrinsic_extrinsic'] = 'intrinsic' | ||
| df.at[index, 'start_lifetime'] = 'both' | ||
| elif row['start_lifetime'] == 'extrinsic': | ||
| df.at[index, 'intrinsic_extrinsic'] = 'extrinsic' | ||
| df.at[index, 'start_lifetime'] = 'both' | ||
|
|
||
| for target, df_target in df.groupby('target'): | ||
| pivot_df = df_target.pivot(index='start_lifetime', columns='intrinsic_extrinsic', values='r_squared') | ||
| pivot_df_std = df_target.pivot(index='start_lifetime', columns='intrinsic_extrinsic', values='stdev') | ||
|
|
||
| fig, ax = plt.subplots(figsize=(10, 8)) | ||
|
|
||
| sns.heatmap(pivot_df, annot=False, cmap=cmap, ax=ax, vmin=0, vmax=0.5) | ||
|
|
||
| first_element = True | ||
| for text_x in range(pivot_df.shape[0]): | ||
| for text_y in range(pivot_df.shape[1]): | ||
| value = pivot_df.iloc[text_x, text_y] | ||
| std_dev = pivot_df_std.iloc[text_x, text_y] | ||
| if not np.isnan(value): | ||
| color = 'white' if first_element else 'black' | ||
| ax.text(text_y+0.5, text_x+0.5, f'{value:.2f} ± {std_dev:.2f}', | ||
| horizontalalignment='center', | ||
| verticalalignment='center', | ||
| color=color, fontsize=16) | ||
| first_element = False | ||
|
|
||
| ax.set_xticklabels(['','Extrinsic','Intrinsic']) | ||
| ax.xaxis.tick_top() | ||
| ax.set_yticklabels(['', 'Both', 'Lifetime', 'Start of growth'], rotation=0) | ||
|
|
||
| ax.set_xlabel('') | ||
| ax.set_ylabel('') | ||
| ax.tick_params(axis='both', which='both', length=0) | ||
| title = ax.set_title(f'Target: {get_plot_labels_for_metric(target)[1]}', loc='left') | ||
| title.set_position([-0.1,1]) | ||
| save_and_show_plot(f'{figdir}{target}_prediction_r_squared_matrix_alpha_{df.alpha[0]}') | ||
|
|
||
|
|
||
| def plot_feature_contribution(coef_alpha, test_sc, perms, target, fig_height, figdir): | ||
| """ | ||
| For a given target, plot feature importance for each feature in the linear model at a specified alpha. | ||
| Features that touch 0 are considered not important and are excluded from the plot. | ||
|
|
||
| Paramaters | ||
| ------- | ||
| coef_alpha: pd.DataFrame | ||
| DataFrame containing the coefficient importance for each feature | ||
| test_sc: pd.DataFrame | ||
| DataFrame containing the test r2 scores | ||
| perms: pd.DataFrame | ||
| DataFrame containing the permutation test results | ||
| target: str | ||
| Prediction feature | ||
| fig_height: int | ||
| Height of the figure based on number of important features | ||
| save_path: str | ||
| Path to save the plot | ||
|
|
||
| Returns | ||
| ------- | ||
| Figure | ||
| """ | ||
|
|
||
| alpha = coef_alpha["alpha"].unique()[0] | ||
chantelleleveille marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| p_value = round(perms["p_value"].item(), 3) | ||
| test_r2_mean = round(test_sc["Test r$^2$"].mean(), 2) | ||
| test_r2_std = round(test_sc["Test r$^2$"].std() / 2, 2) | ||
|
|
||
| for col, df_col in coef_alpha.groupby("Column"): | ||
| lower_bound = df_col["Coefficient Importance"].mean() - df_col["Coefficient Importance"].std() | ||
| upper_bound = df_col["Coefficient Importance"].mean() + df_col["Coefficient Importance"].std() | ||
| if lower_bound < 0 and upper_bound > 0 or df_col["Coefficient Importance"].mean() == 0: | ||
| coef_alpha = coef_alpha[coef_alpha["Column"] != col] | ||
chantelleleveille marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| coef_alpha['Magnitude coefficient importance'] = abs(coef_alpha['Coefficient Importance']) | ||
| coef_alpha['Sign'] = coef_alpha['Coefficient Importance'].apply(lambda x: 'Positive coefficient' if x > 0 else 'Negative coefficient') | ||
|
|
||
| coef_alpha['Mean Magnitude'] = coef_alpha.groupby('Column')['Magnitude coefficient importance'].transform('mean') | ||
| coef_alpha = coef_alpha.sort_values('Mean Magnitude', ascending=False).drop(columns=['Mean Magnitude']) | ||
|
|
||
| plt.figure(figsize=(4,fig_height*.5)) | ||
| ax = sns.barplot( | ||
| data=coef_alpha, | ||
| y="Column", | ||
| x="Magnitude coefficient importance", | ||
| hue="Sign", | ||
| palette={'Positive coefficient': '#156082', 'Negative coefficient': 'grey'}, | ||
| errorbar="sd", | ||
| width=0.7, | ||
| native_scale=True) | ||
|
|
||
| for patch in ax.patches: | ||
| patch.set_edgecolor('black') | ||
| patch.set_linewidth(1.5) | ||
| current_height = patch.get_height() | ||
| desired_height = 0.7 | ||
| patch.set_height(desired_height) | ||
| patch.set_y(patch.get_y() + (current_height - desired_height) * 0.5) | ||
|
|
||
| ax.spines['top'].set_visible(False) | ||
| ax.spines['right'].set_visible(False) | ||
|
|
||
| plt.ylabel("") | ||
| plt.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left') | ||
| if target == 'delta_volume_BC': | ||
| ax.get_legend().remove() | ||
| plt.title(f"Target: {get_plot_labels_for_metric(target)[1]}, alpha={alpha}, test r\u00B2={test_r2_mean}±{test_r2_std}, P={p_value}") | ||
| label_list = [get_plot_labels_for_metric(col)[1] for col in coef_alpha["Column"].unique()] | ||
|
|
||
| plt.yticks(ticks=range(len(label_list)), labels=label_list) | ||
| save_and_show_plot(f'{figdir}/coefficients_{target}_alpha_{alpha}', dpi=300, bbox_inches='tight') | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.