Skip to content

Commit eda2de6

Browse files
committed
Added optional scaling
1 parent d815fcc commit eda2de6

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

app/streamlit_app.py

Lines changed: 16 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
)
2121
from sklearn.datasets import fetch_openml, load_digits, load_iris
2222
from sklearn.decomposition import PCA
23+
from sklearn.preprocessing import StandardScaler
2324
from umap import UMAP
2425

2526
from tdamapper.core import aggregate_graph
@@ -349,6 +350,7 @@ def mapper_lens_input_section(X):
349350

350351
def mapper_cover_input_section():
351352
st.header("🌐 Cover")
353+
scale_cover = st.checkbox("Apply Scaling", value=False, key="scale_cover")
352354
cover_type = st.selectbox(
353355
"Type",
354356
options=[
@@ -395,7 +397,7 @@ def mapper_cover_input_section():
395397
elif cover_type == V_COVER_KNN:
396398
knn_k = st.number_input("Neighbors", value=10, min_value=1)
397399
cover = KNNCover(neighbors=knn_k)
398-
return cover
400+
return cover, scale_cover
399401

400402

401403
def mapper_clustering_cover():
@@ -564,6 +566,7 @@ def mapper_clustering_affinityprop():
564566

565567
def mapper_clustering_input_section():
566568
st.header("🧮 Clustering")
569+
scale_clustering = st.checkbox("Apply Scaling", value=False, key="scale_clustering")
567570
clustering_type = st.selectbox(
568571
"Type",
569572
options=[
@@ -592,32 +595,36 @@ def mapper_clustering_input_section():
592595
clustering = mapper_clustering_hdbscan()
593596
elif clustering_type == V_CLUSTERING_AFFINITY_PROPAGATION:
594597
clustering = mapper_clustering_affinityprop()
595-
return clustering
598+
return clustering, scale_clustering
596599

597600

598601
@st.cache_data(
599602
hash_funcs={"tdamapper.learn.MapperAlgorithm": MapperAlgorithm.__repr__},
600603
show_spinner="Computing Mapper",
601604
)
602-
def compute_mapper(mapper, X, y):
605+
def compute_mapper(mapper, X, y, scale_clustering, scale_cover):
603606
logger.info("Generating Mapper graph")
607+
if scale_clustering:
608+
X = StandardScaler().fit_transform(X)
609+
if scale_cover:
610+
y = StandardScaler().fit_transform(y)
604611
mapper_graph = mapper.fit_transform(X, y)
605612
return mapper_graph
606613

607614

608615
def mapper_input_section(X):
609616
lens = mapper_lens_input_section(X)
610617
st.divider()
611-
cover = mapper_cover_input_section()
618+
cover, scale_cover = mapper_cover_input_section()
612619
st.divider()
613-
clustering = mapper_clustering_input_section()
620+
clustering, scale_clustering = mapper_clustering_input_section()
614621
mapper_algo = MapperAlgorithm(
615622
cover=cover,
616623
clustering=clustering,
617624
verbose=False,
618625
n_jobs=-2,
619626
)
620-
mapper_graph = compute_mapper(mapper_algo, X, lens)
627+
mapper_graph = compute_mapper(mapper_algo, X, lens, scale_clustering, scale_cover)
621628
return mapper_graph
622629

623630

@@ -707,20 +714,14 @@ def _hash_mapper_plot(mapper_plot):
707714
hash_funcs={"tdamapper.plot.MapperPlot": _hash_mapper_plot},
708715
show_spinner="Rendering Mapper",
709716
)
710-
def compute_mapper_fig(mapper_plot, colors, node_size, cmap, _agg, agg_name):
717+
def compute_mapper_fig(mapper_plot, colors, _agg, agg_name):
711718
logger.info("Generating Mapper figure")
712719
mapper_fig = mapper_plot.plot_plotly(
713720
colors,
714-
node_size=[
715-
0.0,
716-
node_size / 2.0,
717-
node_size,
718-
node_size * 1.5,
719-
node_size * 2.0,
720-
],
721+
node_size=[0.25 * i for i in range(9)],
721722
agg=_agg,
722723
title=[f"{c}" for c in colors.columns],
723-
cmap=cmap,
724+
cmap=["Jet", "Viridis", "Cividis"],
724725
width=600,
725726
height=600,
726727
)
@@ -735,9 +736,7 @@ def mapper_figure_section(df_X, df_y, mapper_plot):
735736
mapper_fig = compute_mapper_fig(
736737
mapper_plot,
737738
colors=colors,
738-
node_size=1.0,
739739
_agg=agg,
740-
cmap=["Jet", "Viridis", "Cividis"],
741740
agg_name=agg_name,
742741
)
743742
mapper_fig.update_layout(

0 commit comments

Comments
 (0)