diff --git a/thema/__init__.py b/thema/__init__.py index 7d72e4fb..e146bacb 100644 --- a/thema/__init__.py +++ b/thema/__init__.py @@ -27,6 +27,91 @@ from .multiverse.universe.utils.starGraph import starGraph from .thema import Thema +import logging +import warnings + +# Suppress sklearn deprecation warnings about force_all_finite -> ensure_all_finite +warnings.filterwarnings( + "ignore", + message=".*'force_all_finite' was renamed to 'ensure_all_finite'.*", + category=FutureWarning, + module="sklearn.*", +) + + +def enable_logging(level="INFO"): + """ + Enable thema logging for interactive use (e.g., notebooks). + + Parameters + ---------- + level : str, optional + Logging level ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL') + Default is 'INFO' for moderate verbosity. + Use 'DEBUG' for detailed operational info. + Use 'WARNING' for warnings and errors only. + Use 'ERROR' for errors only. + + Examples + -------- + >>> import thema + >>> thema.enable_logging('DEBUG') # Detailed logging + >>> thema.enable_logging('INFO') # Moderate logging + >>> thema.enable_logging('WARNING') # Warnings/errors only + """ + thema_logger = logging.getLogger("thema") + + for handler in thema_logger.handlers[:]: + if not isinstance(handler, logging.NullHandler): + thema_logger.removeHandler(handler) + + handler = logging.StreamHandler() + formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + thema_logger.addHandler(handler) + + # Set level + level_map = { + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + log_level = level_map.get(level.upper(), logging.INFO) + thema_logger.setLevel(log_level) + + for name in list(logging.Logger.manager.loggerDict.keys()): + if name.startswith("thema."): + child_logger = logging.getLogger(name) + # Reset level to NOTSET so parent logger controls the level + child_logger.setLevel(logging.NOTSET) + + thema_logger.propagate = False + + print(f"Thema logging enabled at {level.upper()} level") + + +def disable_logging(): + """ + Disable thema logging (return to quiet mode). 
+ + Examples + -------- + >>> import thema + >>> thema.disable_logging() + """ + thema_logger = logging.getLogger("thema") + + for handler in thema_logger.handlers[:]: + if not isinstance(handler, logging.NullHandler): + thema_logger.removeHandler(handler) + + thema_logger.setLevel(logging.ERROR) + + print("Thema logging disabled (errors only)") + + # Package metadata __version__ = "0.1.3" __author__ = "Krv-Analytics" @@ -41,4 +126,6 @@ "Star", "Galaxy", "starGraph", + "enable_logging", + "disable_logging", ] diff --git a/thema/config.py b/thema/config.py index 3741ea7d..dd7855bb 100644 --- a/thema/config.py +++ b/thema/config.py @@ -146,3 +146,27 @@ class jmapObservatoryConfig: star_to_observatory = { "jmapStar": "jmapObservatoryConfig", } + +# Map from filter YAML tags to filter functions and their parameter names +filter_configs = { + "component_count": { + "function": "component_count_filter", + "params": {"target_components": 1} + }, + "component_count_range": { + "function": "component_count_range_filter", + "params": {"min_components": 1, "max_components": 10} + }, + "minimum_nodes": { + "function": "minimum_nodes_filter", + "params": {"min_nodes": 3} + }, + "minimum_edges": { + "function": "minimum_edges_filter", + "params": {"min_edges": 2} + }, + "minimum_unique_items": { + "function": "minimum_unique_items_filter", + "params": {"min_unique_items": 10} + } +} diff --git a/thema/multiverse/system/inner/moon.py b/thema/multiverse/system/inner/moon.py index 087f072e..62e9a432 100644 --- a/thema/multiverse/system/inner/moon.py +++ b/thema/multiverse/system/inner/moon.py @@ -1,9 +1,9 @@ # File: multiverse/system/inner/moon.py -# Last Update: 05/15/24 -# Updated By: JW +# Last Update: 10/15/25 +# Updated By: SG import pickle -import warnings +import logging import category_encoders as ce import pandas as pd @@ -12,6 +12,9 @@ from ....core import Core from . import inner_utils +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + class Moon(Core): """ @@ -122,81 +125,79 @@ def __init__( self.seed = seed self.imputeData = None - def fit(self): - """ - Performs the cleaning procedure according to the constructor arguments. - Initializes the imputeData member as a DataFrame, which is a scaled, - numeric, and complete representation of the original raw data set. + # Log initial state + logger.debug(f"Moon initialized with data shape: {self.data.shape}") + logger.debug(f"Drop columns: {self.dropColumns}") + logger.debug(f"Impute columns: {self.imputeColumns}") + logger.debug(f"Impute methods: {self.imputeMethods}") + logger.debug( + f"Encoding: {self.encoding}, Scaler: {self.scaler}, Seed: {self.seed}" + ) - Examples - ---------- - >>> moon = Moon() - >>> moon.fit() - """ + def fit(self): + # Add imputed flags + self.imputeData = inner_utils.add_imputed_flags(self.data, self.imputeColumns) + logger.debug("Added imputed flags to columns") + logger.debug(f"Data shape after adding flags: {self.imputeData.shape}") - self.imputeData = inner_utils.add_imputed_flags( - self.data, self.imputeColumns - ) + # Apply imputation for index, column in enumerate(self.imputeColumns): impute_function = getattr(inner_utils, self.imputeMethods[index]) - self.imputeData[column] = impute_function( - self.data[column], self.seed + self.imputeData[column] = impute_function(self.data[column], self.seed) + logger.debug( + f"Column '{column}' imputed using '{self.imputeMethods[index]}'. 
" + f"NaNs remaining: {self.imputeData[column].isna().sum()}" ) - self.dropColumns = [ - col for col in self.dropColumns if col in self.data.columns - ] - # Drop Columns - if not self.dropColumns == []: - self.imputeData = self.data.drop(columns=self.dropColumns) + # Drop specified columns + self.dropColumns = [col for col in self.dropColumns if col in self.data.columns] + if self.dropColumns: + before_drop = self.imputeData.shape + self.imputeData = self.imputeData.drop(columns=self.dropColumns) + logger.debug( + f"Dropped columns: {self.dropColumns}. Shape before: {before_drop}, after: {self.imputeData.shape}" + ) - # Drop Rows with Nans + # Drop rows with NaNs + nan_cols = self.imputeData.columns[self.imputeData.isna().any()] + logger.debug(f"Columns with NaN values before dropping rows: {list(nan_cols)}") self.imputeData.dropna(axis=0, inplace=True) + logger.debug(f"Shape after dropping rows with NaNs: {self.imputeData.shape}") - if type(self.encoding) == str: + # Ensure encoding is a list + if isinstance(self.encoding, str): self.encoding = [ self.encoding for _ in range( - len( - self.imputeData.select_dtypes( - include=["object"] - ).columns - ) + len(self.imputeData.select_dtypes(include=["object"]).columns) ) ] # Encoding - assert len(self.encoding) == len( - self.imputeData.select_dtypes(include=["object"]).columns - ), f"length of encoding: {len(self.encoding)}, length of cat variables: {len(self.imputeData.select_dtypes(include=['object']).columns)}" - - for i, column in enumerate( - self.imputeData.select_dtypes(include=["object"]).columns - ): - encoding = self.encoding[i] - - if encoding == "one_hot": - if self.imputeData[column].dtype == object: - self.imputeData = pd.get_dummies( - self.imputeData, prefix=f"OH_{column}", columns=[column] - ) - - elif encoding == "integer": - if self.imputeData[column].dtype == object: - vals = self.imputeData[column].values - self.imputeData[column] = inner_utils.integer_encoder(vals) - - elif encoding == "hash": - if self.imputeData[column].dtype == object: - hashing_encoder = ce.HashingEncoder( - cols=[column], n_components=10 - ) - self.imputeData = hashing_encoder.fit_transform( - self.imputeData - ) - - else: - pass + cat_cols = self.imputeData.select_dtypes(include=["object"]).columns + assert len(self.encoding) == len(cat_cols), ( + f"length of encoding: {len(self.encoding)}, " + f"length of categorical variables: {len(cat_cols)}" + ) + for i, column in enumerate(cat_cols): + encoding_method = self.encoding[i] + if encoding_method == "one_hot" and self.imputeData[column].dtype == object: + self.imputeData = pd.get_dummies( + self.imputeData, prefix=f"OH_{column}", columns=[column] + ) + logger.debug(f"Column '{column}' one-hot encoded") + + elif ( + encoding_method == "integer" and self.imputeData[column].dtype == object + ): + vals = self.imputeData[column].values + self.imputeData[column] = inner_utils.integer_encoder(vals) + logger.debug(f"Column '{column}' integer encoded") + + elif encoding_method == "hash" and self.imputeData[column].dtype == object: + hashing_encoder = ce.HashingEncoder(cols=[column], n_components=10) + self.imputeData = hashing_encoder.fit_transform(self.imputeData) + logger.debug(f"Column '{column}' hash encoded") # Scaling assert self.scaler in ["standard"], "Invalid Scaler" @@ -206,6 +207,9 @@ def fit(self): scaler.fit_transform(self.imputeData), columns=list(self.imputeData.columns), ) + logger.debug( + f"Data scaled using StandardScaler. 
Final shape: {self.imputeData.shape}" + ) def save(self, file_path): """ @@ -224,3 +228,4 @@ def save(self, file_path): """ with open(file_path, "wb") as f: pickle.dump(self, f) + logger.debug(f"Moon object saved to {file_path}") diff --git a/thema/multiverse/system/inner/planet.py b/thema/multiverse/system/inner/planet.py index 41d6ca08..194da6f3 100644 --- a/thema/multiverse/system/inner/planet.py +++ b/thema/multiverse/system/inner/planet.py @@ -5,17 +5,27 @@ import os import pickle import random +import logging +import time import numpy as np import pandas as pd from omegaconf import ListConfig, OmegaConf from ....core import Core -from ....utils import function_scheduler +from ....utils import ( + function_scheduler, + get_current_logging_config, + configure_child_process_logging, +) from .inner_utils import clean_data_filename from .moon import Moon +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + class Planet(Core): """ Perturb, Label And Navigate Existsing Tabulars @@ -273,7 +283,7 @@ def __init__( if self.outDir is not None and not os.path.isdir(self.outDir): try: - os.makedirs(outDir) + os.makedirs(str(outDir)) except Exception as e: print(e) @@ -316,7 +326,9 @@ def __init__( elif imputeColumns == "auto": - self.imputeColumns = self.data.columns[self.data.isna().any()].tolist() + self.imputeColumns = self.data.columns[ + self.data.isna().any() + ].tolist() elif type(imputeColumns) == ListConfig or type(imputeColumns) == list: self.imputeColumns = imputeColumns @@ -328,14 +340,18 @@ def __init__( self.imputeColumns = [] if imputeMethods is None or imputeMethods == "None": - self.imputeMethods = ["drop" for _ in range(len(self.imputeColumns))] + self.imputeMethods = [ + "drop" for _ in range(len(self.imputeColumns)) + ] elif imputeMethods == "auto": self.imputeMethods = self.get_recomended_sampling_method() elif type(imputeMethods) == str: if not imputeMethods in supported_imputeMethods: print("Invalid impute methods. Defaulting to 'drop'") imputeMethods = "drop" - self.imputeMethods = [imputeMethods for _ in range(len(self.imputeColumns))] + self.imputeMethods = [ + imputeMethods for _ in range(len(self.imputeColumns)) + ] else: assert len(imputeMethods) == len( self.imputeColumns @@ -359,6 +375,15 @@ def __init__( self.numSamples = 1 self.seeds = [42] + # Log Planet configuration + logger.debug(f"Planet initialized with data shape: {self.data.shape}") + logger.debug(f"Number of samples to generate: {self.numSamples}") + logger.debug(f"Impute columns: {self.imputeColumns}") + logger.debug(f"Impute methods: {self.imputeMethods}") + logger.debug(f"Drop columns: {self.dropColumns}") + logger.debug(f"Encoding: {self.encoding}, Scaler: {self.scaler}") + logger.debug(f"Output directory: {self.outDir}") + def get_missingData_summary(self) -> dict: """ Get a summary of missing data in the columns of the 'data' dataframe. 
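The logger.debug calls added to Moon and Planet above stay silent until the package logger is switched on. A minimal sketch of the intended notebook flow, using the enable_logging helper introduced in __init__.py in this change (the params file name is illustrative, matching the Thema docstrings elsewhere in this diff):

>>> import thema
>>> thema.enable_logging("DEBUG")
Thema logging enabled at DEBUG level
>>> t = thema.Thema("params.yaml")
>>> t.genesis()   # Moon/Planet debug lines (shapes, impute methods, drops) now stream to stderr

Call thema.disable_logging() afterwards to return to errors-only output.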
@@ -499,28 +524,83 @@ def fit(self): >>> planet.imputeData.to_pickle("myCleanData") """ + logger.info( + f"Starting Planet.fit() – creating {self.numSamples} Moon object(s)" + ) + logger.debug(f"Using seeds: {self.seeds}") + logger.debug(f"Output directory: {self.outDir}") + + # Get current logging config to pass to child processes + logging_config = get_current_logging_config() + assert len(self.seeds) == self.numSamples subprocesses = [] for i in range(self.numSamples): - cmd = (self._instantiate_moon, i) + logger.debug( + f"Preparing Moon {i+1}/{self.numSamples} with seed {self.seeds[i]}" + ) + cmd = (self._instantiate_moon, i, logging_config) subprocesses.append(cmd) - function_scheduler( + # Pre-count outputs for delta reporting + pre_count = 0 + try: + if self.outDir and os.path.isdir(self.outDir): + pre_count = len( + [f for f in os.listdir(self.outDir) if f.endswith(".pkl")] + ) + except Exception: + pass + + workers = min(4, self.numSamples) + logger.info( + f"Launching {len(subprocesses)} Moon process(es) with max {workers} worker(s)…" + ) + t0 = time.perf_counter() + results = function_scheduler( subprocesses, - max_workers=min(4, self.numSamples), + max_workers=workers, out_message="SUCCESS: Imputation(s)", resilient=True, verbose=self.verbose, ) + t1 = time.perf_counter() - def _instantiate_moon(self, id): + # Post-count outputs for delta reporting + created = None + try: + if self.outDir and os.path.isdir(self.outDir): + post_count = len( + [f for f in os.listdir(self.outDir) if f.endswith(".pkl")] + ) + created = max(0, post_count - pre_count) + except Exception: + created = None + + total = len(subprocesses) + duration = t1 - t0 + if isinstance(results, list) and len(results) == total: + logger.info( + f"Planet.fit() complete in {duration:.2f}s – processed {total} Moon object(s){'' if created is None else f', created ~{created} file(s)'}" + ) + else: + logger.info( + f"Planet.fit() complete in {duration:.2f}s – processed {total} Moon object(s)." + ) + + def _instantiate_moon(self, id, logging_config): """ Helper function for the fit() method. See `fit()` for more details. + This method creates and processes a single Moon instance with proper + logging configuration for multiprocessing environments. + Parameters ---------- id : int Identifier for the Moon instance. + logging_config : dict or None + Logging configuration from parent process. Returns ------- @@ -529,8 +609,10 @@ def _instantiate_moon(self, id): Examples -------- >>> planet = Planet() - >>> planet._instantiate_moon(1) + >>> planet._instantiate_moon(1, logging_config) """ + # Configure logging in this child process + configure_child_process_logging(logging_config) if self.seeds is None: self.seeds = dict() @@ -548,7 +630,9 @@ def _instantiate_moon(self, id): ) my_moon.fit() - filename_without_extension, extension = os.path.splitext(self.get_data_path()) + filename_without_extension, extension = os.path.splitext( + self.get_data_path() + ) data_name = filename_without_extension.split("/")[-1] file_name = clean_data_filename( data_name=data_name, @@ -556,7 +640,8 @@ def _instantiate_moon(self, id): scaler=self.scaler, encoding=self.encoding, ) - output_filepath = os.path.join(self.outDir, file_name) + output_dir = str(self.outDir) + output_filepath = os.path.join(output_dir, file_name) my_moon.save(file_path=output_filepath) @@ -595,42 +680,42 @@ def getParams(self) -> dict: def writeParams_toYaml(self, YAML_PATH=None): """ - Write the specified parameters to a YAML file. 
+ Write or create a YAML file with the Planet parameters. Parameters ---------- - YAML_PATH : str - The path to an existing YAML file. + YAML_PATH : str, optional + Path to an existing or new YAML file. Returns ------- None - - Examples - -------- - >>> planet = Planet() - >>> planet.writeParams_toYaml("config.yaml") - YAML file successfully updated """ if YAML_PATH is None and self.YAML_PATH is not None: YAML_PATH = self.YAML_PATH - if YAML_PATH is None and self.YAML_PATH is None: + if YAML_PATH is None: raise ValueError("Please provide a valid filepath to YAML") - # Check if file exists and is correct type - if not os.path.isfile(YAML_PATH): - raise TypeError("File path does not point to a YAML file") - with open(YAML_PATH, "r") as f: - params = OmegaConf.load(f) + # If the file exists, load it; otherwise start a new config + if os.path.isfile(YAML_PATH): + params = OmegaConf.load(YAML_PATH) + else: + params = OmegaConf.create() + # Update with this object's parameters params.Planet = self.getParams() params.Planet.pop("outDir", None) params.Planet.pop("data", None) - with open(YAML_PATH, "w") as f: - OmegaConf.save(params, f) + # Ensure directory exists + os.makedirs(os.path.dirname(YAML_PATH), exist_ok=True) - print("YAML file successfully updated") + # Save the YAML file + file_exists_before = os.path.isfile(YAML_PATH) + OmegaConf.save(params, YAML_PATH) + print( + f"YAML file successfully {'updated' if file_exists_before else 'created'} at {YAML_PATH}" + ) def save(self, file_path): """ diff --git a/thema/multiverse/system/outer/comet.py b/thema/multiverse/system/outer/comet.py index af33ed83..88ae41aa 100644 --- a/thema/multiverse/system/outer/comet.py +++ b/thema/multiverse/system/outer/comet.py @@ -2,11 +2,16 @@ # Last Update: 05/15/24 # Updated by: JW +import logging import pickle from abc import abstractmethod from ....core import Core +# Configure module logger +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + class Comet(Core): """ @@ -59,6 +64,7 @@ def __init__(self, data_path: str, clean_path: str): super().__init__( data_path=data_path, clean_path=clean_path, projection_path=None ) + @abstractmethod def fit(self): @@ -102,4 +108,6 @@ def save(self, file_path): with open(file_path, "wb") as f: pickle.dump(self, f) except Exception as e: + # Always log errors regardless of level + logger.error(f"Failed to save Comet to {file_path}: {str(e)}") print(e) diff --git a/thema/multiverse/system/outer/oort.py b/thema/multiverse/system/outer/oort.py index d4337102..1e36cd48 100644 --- a/thema/multiverse/system/outer/oort.py +++ b/thema/multiverse/system/outer/oort.py @@ -5,14 +5,25 @@ import glob import importlib import itertools +import logging import os import pickle +import time from omegaconf import OmegaConf from .... 
import config from ....core import Core -from ....utils import create_file_name, function_scheduler +from ....utils import ( + create_file_name, + function_scheduler, + get_current_logging_config, + configure_child_process_logging, +) + +# Configure module logger +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) class Oort(Core): @@ -208,6 +219,16 @@ def __init__( except Exception as e: print(e) + # Log Oort initialization + clean_files = len(os.listdir(self.cleanDir)) + logger.info( + f"Oort initialized with {len(self.params)} projector type(s) and {clean_files} clean data file(s)" + ) + logger.debug(f"Clean directory: {self.cleanDir}") + logger.debug(f"Output directory: {self.outDir}") + for proj_name, proj_params in self.params.items(): + logger.debug(f"Projector '{proj_name}' parameters: {proj_params}") + def fit(self): """ Configure and run your projections. @@ -225,8 +246,16 @@ def fit(self): >>> oort = Oort() >>> oort.fit() """ + logger.info( + f"Starting Oort.fit() – processing {len(self.params)} projector type(s)" + ) + + # Get current logging config to pass to child processes + logging_config = get_current_logging_config() + subprocesses = [] for projectorName, projectorParamsDict in self.params.items(): + logger.debug(f"Setting up projector: {projectorName}") projConfig = config.tag_to_class[projectorName] cfg = getattr(config, projConfig) module = importlib.import_module(cfg.module, package="thema") @@ -234,6 +263,10 @@ def fit(self): file_pattern = os.path.join(self.cleanDir, "*.pkl") valid_files = glob.glob(file_pattern) + logger.debug( + f"Found {len(valid_files)} clean files for projector '{projectorName}'" + ) + for i, cleanFile in enumerate(valid_files): parameter_combinations = itertools.product( itertools.product( @@ -245,6 +278,7 @@ def fit(self): ) ) cleanFile = os.path.join(self.cleanDir, cleanFile) + param_count = 0 for j, combination in enumerate(parameter_combinations): projectorParameters = { key: value @@ -260,20 +294,82 @@ def fit(self): projectorParameters, projectorName, f"{j}_{i}", + logging_config, ) subprocesses.append(cmd) + param_count += 1 + logger.debug( + f"Generated {param_count} parameter combinations for clean file {i+1}" + ) + + # Pre-count outputs for delta reporting + pre_count = 0 + try: + if self.outDir and os.path.isdir(self.outDir): + pre_count = len( + [f for f in os.listdir(self.outDir) if f.endswith(".pkl")] + ) + except Exception: + pass # TODO: optimize max-Works based on OS availability - function_scheduler( + total_projections = len(subprocesses) + workers = 4 + logger.info( + f"Launching {total_projections} projection process(es) with max {workers} worker(s)…" + ) + + t0 = time.perf_counter() + results = function_scheduler( subprocesses, 4, "SUCCESS: Projection(s)", resilient=True, verbose=self.verbose, ) + t1 = time.perf_counter() + + # Log completion stats + if results: + failed_count = sum(1 for r in results if r is False) + success_count = total_projections - failed_count + success_rate = ( + (success_count / total_projections * 100) + if total_projections > 0 + else 0 + ) + # Post-count outputs for delta reporting + created = None + try: + if self.outDir and os.path.isdir(self.outDir): + post_count = len( + [ + f + for f in os.listdir(self.outDir) + if f.endswith(".pkl") + ] + ) + created = max(0, post_count - pre_count) + except Exception: + created = None + + logger.info( + f"Oort.fit() complete in {t1 - t0:.2f}s – {success_count}/{total_projections} ({success_rate:.1f}%) successful{'' if created 
is None else f', created ~{created} file(s)'}" + ) + if failed_count > 0: + logger.warning(f"{failed_count} projections failed") + else: + logger.info(f"Oort.fit() complete in {t1 - t0:.2f}s") def _instantiate_projection( - self, data, cleanFile, projector, projectorParameters, projectorName, id + self, + data, + cleanFile, + projector, + projectorParameters, + projectorName, + id, + logging_config, ): """ Helper function for the fit() method. Creates a projectile instance @@ -293,10 +389,13 @@ def _instantiate_projection( Name of the projector class. id : int Identifier. + logging_config : dict or None + Logging configuration from parent process. Returns ------- - None + bool + True if successful, False otherwise. See Also -------- @@ -311,17 +410,27 @@ def _instantiate_projection( >>> projectorParameters = {"param1": 10, "param2": "abc"} >>> projectorName = "MyProjector" >>> id = 1 - >>> _instantiate_projection(data, cleanFile, projector, projectorParameters, projectorName, id) + >>> _instantiate_projection(data, cleanFile, projector, projectorParameters, projectorName, id, logging_config) """ - my_projector = projector( - data_path=data, clean_path=cleanFile, **projectorParameters - ) - my_projector.fit() - output_file = create_file_name( - className=projectorName, classParameters=projectorParameters, id=id - ) - output_file = os.path.join(self.outDir, output_file) - my_projector.save(output_file) + # Configure logging in this child process + configure_child_process_logging(logging_config) + + try: + my_projector = projector( + data_path=data, clean_path=cleanFile, **projectorParameters + ) + my_projector.fit() + output_file = create_file_name( + className=projectorName, + classParameters=projectorParameters, + id=id, + ) + output_file = os.path.join(self.outDir, output_file) + my_projector.save(output_file) + return True + except Exception as e: + logger.error(f"Projection {projectorName} #{id} failed: {str(e)}") + return False def getParams(self): """ diff --git a/thema/multiverse/system/outer/projectiles/pcaProj.py b/thema/multiverse/system/outer/projectiles/pcaProj.py index 797dc9ca..948711fe 100644 --- a/thema/multiverse/system/outer/projectiles/pcaProj.py +++ b/thema/multiverse/system/outer/projectiles/pcaProj.py @@ -2,10 +2,15 @@ # Last Update: 05/15/24 # Updated by: JW +import logging from sklearn.decomposition import PCA from ..comet import Comet +# Configure module logger +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + def initialize(): """ @@ -73,6 +78,7 @@ def __init__(self, data_path, clean_path, dimensions, seed): self.dimensions = dimensions self.seed = seed self.projectionArray = None + def fit(self): """ diff --git a/thema/multiverse/system/outer/projectiles/tsneProj.py b/thema/multiverse/system/outer/projectiles/tsneProj.py index 624c6ff4..b5d64a65 100644 --- a/thema/multiverse/system/outer/projectiles/tsneProj.py +++ b/thema/multiverse/system/outer/projectiles/tsneProj.py @@ -2,10 +2,15 @@ # Last Update: 05/15/24 # Updated by: JW +import logging from sklearn.manifold import TSNE from ..comet import Comet +# Configure module logger +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + def initialize(): """ @@ -75,6 +80,7 @@ def __init__(self, data_path, clean_path, perplexity, dimensions, seed): self.perplexity = perplexity self.dimensions = dimensions self.seed = seed + def fit(self): """ diff --git a/thema/multiverse/universe/galaxy.py b/thema/multiverse/universe/galaxy.py index 14f242df..02d40c9d 100644 
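The filter_configs table added to thema/config.py earlier in this diff is consumed by the Galaxy changes below: Galaxy._setup_filter reads Galaxy.filter and Galaxy.filter_params from the params YAML, merges them over the defaults, and builds the matching starFilters callable. A hedged sketch of that resolution, mirroring _setup_filter (the tag and override values are illustrative):

>>> from omegaconf import OmegaConf
>>> from thema import config
>>> from thema.multiverse.universe.utils import starFilters
>>> params = OmegaConf.create(
...     {"Galaxy": {"filter": "minimum_nodes", "filter_params": {"min_nodes": 5}}}
... )
>>> tag = params.Galaxy.get("filter")     # 'minimum_nodes'
>>> entry = config.filter_configs[tag]    # {'function': 'minimum_nodes_filter', 'params': {'min_nodes': 3}}
>>> kwargs = {**entry["params"], **params.Galaxy.get("filter_params", {})}
>>> filter_fn = getattr(starFilters, entry["function"])(**kwargs)   # YAML override (min_nodes=5) wins over the default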
--- a/thema/multiverse/universe/galaxy.py +++ b/thema/multiverse/universe/galaxy.py @@ -1,15 +1,19 @@ # File: multiverse/universe.py -# Lasted Updated: 05/15/24 -# Updated By: JW +# Lasted Updated: 10/21/25 +# Updated By: SG import glob import importlib import itertools +import logging import os import pickle +from collections import Counter +import time +from typing import cast import numpy as np -import pandas as pd +import networkx as nx from omegaconf import OmegaConf from sklearn.cluster import AgglomerativeClustering from sklearn.manifold import MDS @@ -17,9 +21,16 @@ from .utils import starFilters, starSelectors from ... import config -from ...utils import create_file_name, function_scheduler +from ...utils import ( + create_file_name, + function_scheduler, + get_current_logging_config, +) from . import geodesics +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + class Galaxy: """ @@ -110,6 +121,7 @@ def __init__( metric="stellar_curvature_distance", selector="max_nodes", nReps=3, + filter_fn=None, YAML_PATH=None, verbose=False, ): @@ -135,6 +147,9 @@ def __init__( {"star0_name" : { "star0_parameter0":[list of star0_parameter0 values], "star0_parameter1": [list of star0_parameter1 values]}, "star1_name": {"star1_parameter0": [list of star1_parameter0 values]} } + filter_fn: str, callable, or None, optional + Filter function to apply to stars before distance calculations. + Can be a string name of a function in starFilters, a callable, or None for no filtering. YAML_PATH : str, optional The path to a YAML file containing configuration settings. Default is None. verbose: bool @@ -158,7 +173,7 @@ def __init__( metric = yamlParams.Galaxy.metric selector = yamlParams.Galaxy.selector nReps = yamlParams.Galaxy.nReps - filterFn = yamlParams.Galaxy.filter + filter_fn = yamlParams.Galaxy.get("filter", None) if type(yamlParams.Galaxy.stars) == str: stars = [yamlParams.Galaxy.stars] @@ -179,11 +194,14 @@ def __init__( self.cleanDir = cleanDir self.projDir = projDir self.outDir = outDir + self.YAML_PATH = YAML_PATH self.metric = metric self.selector = selector self.nReps = nReps - self.filterFn = filterFn + # Store YAML params for filter setup later (avoid pickling issues) + self._yaml_filter = filter_fn + self._yamlParams = yamlParams if YAML_PATH is not None else None self.keys = None self.distances = None @@ -211,6 +229,66 @@ def __init__( except Exception as e: print(e) + self.data = cast(str, self.data) + self.cleanDir = cast(str, self.cleanDir) + self.projDir = cast(str, self.projDir) + self.outDir = cast(str, self.outDir) + + def _setup_filter(self, yamlParams): + logger.info("Checking yaml for filter configuration.") + if yamlParams and yamlParams.Galaxy.get("filter"): + filter_type = yamlParams.Galaxy.get("filter") + if filter_type in config.filter_configs: + + filter_config = config.filter_configs[filter_type] + logger.info(f"Loading supported filter function: `{filter_type}`") + params = { + **filter_config["params"], + **yamlParams.Galaxy.get("filter_params", {}), + } + logger.info(f"Using filter parameters: {params}") + func = getattr(starFilters, filter_config["function"])(**params) + # Tag the callable with a human-friendly name for logging + try: + setattr(func, "_display_name", str(filter_type)) + except Exception: + pass + return func + + # Default to no-op filter with a stable display name + nf = starFilters.nofilterfunction + try: + setattr(nf, "_display_name", "nofilterfunction") + except Exception: + pass + return nf + + def 
_log_graph_distribution(self, files_to_use): + + out_dir = cast(str, self.outDir) + file_paths = [ + os.path.join(out_dir, f) for f in os.listdir(out_dir) if f.endswith(".pkl") + ] + component_counts = [] + + for file_path in file_paths: + try: + with open(file_path, "rb") as f: + star_obj = pickle.load(f) + if star_obj.starGraph and star_obj.starGraph.graph: + component_counts.append( + nx.number_connected_components(star_obj.starGraph.graph) + ) + except: + continue + + if component_counts: + counts = Counter(component_counts) + logger.debug("Component distribution:") + for n, count in sorted(counts.items()): + bar = "█" * count + logger.debug(f" {n:>2} components: {bar} ({count})") + def fit(self): """ Configure and generate space of Stars. @@ -223,6 +301,9 @@ def fit(self): Saves star objects to outDir and prints a count of failed saves. """ + # Get current logging config to pass to child processes + logging_config = get_current_logging_config() + subprocesses = [] for starName, starParamsDict in self.params.items(): @@ -232,14 +313,16 @@ def fit(self): star = module.initialize() # Load matching files - cleanfile_pattern = os.path.join(self.cleanDir, "*.pkl") + clean_dir = cast(str, self.cleanDir) + cleanfile_pattern = os.path.join(clean_dir, "*.pkl") valid_cleanFiles = glob.glob(cleanfile_pattern) - projfile_pattern = os.path.join(self.projDir, "*.pkl") + proj_dir = cast(str, self.projDir) + projfile_pattern = os.path.join(proj_dir, "*.pkl") valid_projFiles = glob.glob(projfile_pattern) for j, projFile in enumerate(valid_projFiles): - projFilePath = os.path.join(self.projDir, projFile) + projFilePath = os.path.join(proj_dir, projFile) with open(projFilePath, "rb") as f: cleanFile = pickle.load(f).get_clean_path() @@ -265,6 +348,7 @@ def fit(self): starParameters, starName, f"{k}_{j}", + logging_config, ) ) @@ -277,9 +361,8 @@ def fit(self): ) failed_saves = sum(1 for r in results if r is False) - print( - f"\n⭐️ {len(results)-failed_saves}({(len(results)- failed_saves)/len(results)*100}%) star objects successfully saved." - ) + if failed_saves > 0: + logger.warning(f"{failed_saves}/{len(results)} star saves failed") def _instantiate_star( self, @@ -290,6 +373,7 @@ def _instantiate_star( starParameters, starName, id, + logging_config, ): """Helper function for the fit() method. Creates a Star instances and fits it. @@ -309,25 +393,40 @@ def _instantiate_star( Name of star class id : int Identifier + logging_config : dict or None + Logging configuration from parent process Returns ------- - None + bool + True if saved successfully, False otherwise See Also -------- `Star` class and stars directory for more info on an individual fit. 
""" - my_star = star( - data_path=data_path, - clean_path=cleanFile, - projection_path=projFile, - **starParameters, - ) - my_star.fit() - output_file = create_file_name(starName, starParameters, id) - output_file = os.path.join(self.outDir, output_file) - return my_star.save(output_file) + # Configure logging in this child process + from ...utils import configure_child_process_logging + + configure_child_process_logging(logging_config) + + try: + my_star = star( + data_path=data_path, + clean_path=cleanFile, + projection_path=projFile, + **starParameters, + ) + my_star.fit() + output_file = create_file_name(starName, starParameters, id) + out_dir = cast(str, self.outDir) + output_file = os.path.join(out_dir, output_file) + return my_star.save(output_file) + except Exception as e: + logger.error( + f"Star {starName} #{id} failed - params: {starParameters}, error: {str(e)}" + ) + return False def collapse( self, @@ -335,75 +434,260 @@ def collapse( nReps=None, selector=None, filter_fn=None, + files: list | None = None, + distance_threshold: float | None = None, **kwargs, ): """ - Collapses the space of Stars into a small number of representative Stars + Collapses the space of Stars into representative Stars. + Either nReps (number of clusters) or distance_threshold (AgglomerativeClustering) can be used. Parameters ---------- - metric: str - metric used when comparing graphs. Currently, supported types are `stellar_curvature_distance` - and `stellar_kernel_distance`. - nReps: int - The number of representative stars - selector: str - The selection criteria to choose representatives from a cluster. Currently, only "random" supported. - **kwargs: - Arguments necessary for different metric functions. + metric : str, optional + Metric function name for comparing graphs. Defaults to self.metric. + nReps : int, optional + Number of clusters for AgglomerativeClustering. Ignored if distance_threshold is set. + selector : str, optional + Selection function name to choose representative stars. Defaults to self.selector. + filter_fn : callable, str, or None + Filter function to select a subset of graphs. Defaults to no filter. + files : list[str] or None + Optional list of file paths to process. Defaults to self.outDir. + distance_threshold : float, optional + AgglomerativeClustering distance threshold. Used if nReps is None. + **kwargs : + Additional arguments passed to the metric function. Returns ------- dict - A dictionary containing the path to the star and the size of the group it represents. + Mapping from cluster labels to selected stars and cluster sizes. 
""" + logger.info("Configuring Galaxy Collapse…") + metric = metric or self.metric + selector = selector or self.selector + # Set up filter when needed + + if callable(filter_fn): + logger.info( + f"Using provided filter function: {getattr(filter_fn, '__name__', str(type(filter_fn)))}" + ) + elif filter_fn is None: + filter_fn = self._setup_filter(self._yamlParams) - if metric is None: - metric = self.metric - if nReps is None: - nReps = self.nReps - if selector is None: - selector = self.selector - - if filter_fn is None: - filter_fn = self.filterFn + elif isinstance(filter_fn, str): + logger.info( + f"Function name provided, attempting to load from supported star filters: {filter_fn}" + ) + filter_callable = getattr( + starFilters, filter_fn, starFilters.nofilterfunction + ) + # Tag display name for logging + try: + setattr(filter_callable, "_display_name", str(filter_fn)) + except Exception: + pass + filter_fn = filter_callable + logger.info( + f"Loaded filter function: {getattr(filter_fn, '__name__', str(type(filter_fn)))}" + ) + else: + filter_fn = starFilters.nofilterfunction + try: + setattr(filter_fn, "_display_name", "nofilterfunction") + except Exception: + pass + logger.info(f"Defaulting to : {filter_fn.__name__}") + + if not callable(filter_fn): + raise ValueError( + f"filter_fn must be None, callable, or string, got {type(filter_fn)}" + ) metric_fn = getattr(geodesics, metric, geodesics.stellar_curvature_distance) selector_fn = getattr(starSelectors, selector, starSelectors.max_nodes) - # Handle the filter function - if filter_fn is None: - filter_fn = starFilters.nofilterfunction + # Filter/metric/selector names for readability + filter_fn_name = getattr( + filter_fn, + "_display_name", + getattr(filter_fn, "__name__", str(type(filter_fn))), + ) + logger.info( + f"Filter: {filter_fn_name} | Metric: {metric} | Selector: {selector}" + ) + + # Determine files to process + files_to_use = files if files is not None else self.outDir + + # Build a robust view of file list for logging (without changing behavior) + file_list: list[str] + out_dir = cast(str, self.outDir) + if files is None: + file_list = [ + os.path.join(out_dir, f) + for f in os.listdir(out_dir) + if f.endswith(".pkl") + ] else: - filter_fn = getattr(starFilters, filter_fn, starFilters.nofilterfunction) + if isinstance(files, (list, tuple)): + file_list = list(files) + elif isinstance(files, str) and os.path.isdir(files): + dir_str = cast(str, files) + file_list = [ + os.path.join(dir_str, f) + for f in os.listdir(dir_str) + if f.endswith(".pkl") + ] + else: + # Fallback: treat as a single path + file_list = [str(files)] + + total_files = len(file_list) + target_desc = ( + f"directory '{self.outDir}'" + if files is None + else f"{total_files} provided file(s)" + ) + logger.info(f"Scanning {total_files} candidate graph(s) from {target_desc}.") + + # Show graph distribution before filtering if DEBUG enabled + if logger.isEnabledFor(logging.DEBUG): + self._log_graph_distribution(files_to_use) + + # Determine concrete type to pass to metric function: either directory (str) or list[str] + out_dir: str = cast(str, self.outDir) + if files is None: + metric_files: str | list[str] = out_dir + else: + if isinstance(files, (list, tuple)): + metric_files = [str(f) for f in files] + elif isinstance(files, str) and os.path.isdir(files): + metric_files = files + else: + metric_files = [str(files)] + # Compute distances with timing + t0 = time.perf_counter() self.keys, self.distances = metric_fn( - files=self.outDir, 
filterfunction=filter_fn, **kwargs + files=metric_files, filterfunction=filter_fn, **kwargs ) + t1 = time.perf_counter() + filtered_count = len(self.keys) + logger.info( + f"Filter results: {filtered_count}/{total_files} graph(s) passed the filter in {t1 - t0:.2f}s" + ) + + # Distance matrix quick stats (off-diagonal) + try: + n = self.distances.shape[0] + if n == self.distances.shape[1] and n == filtered_count and n > 1: + mask = ~np.eye(n, dtype=bool) + dvals = self.distances[mask] + finite = np.isfinite(dvals) + if not np.all(finite): + bad = np.size(dvals) - np.count_nonzero(finite) + logger.warning( + f"Distance matrix contains {bad} non-finite value(s) (NaN/inf)." + ) + if np.any(finite): + dvals_f = dvals[finite] + logger.debug( + "Distance stats (off-diagonal, finite): min=%.4f | mean=%.4f | max=%.4f | count=%d", + float(np.min(dvals_f)), + float(np.mean(dvals_f)), + float(np.max(dvals_f)), + int(dvals_f.size), + ) + except Exception: + # Keep logging resilient + pass + + # Check if we have enough graphs for clustering + if filtered_count < 2: + raise ValueError( + f"Only {filtered_count} graph(s) passed the filter. " + "Clustering requires at least 2 graphs. " + "Consider relaxing your filter criteria." + ) + + # Use nReps or distance_threshold for AgglomerativeClustering + # Handle clustering configuration clarity + if nReps is None and distance_threshold is None: + nReps = self.nReps + + if nReps is not None and distance_threshold is not None: + logger.warning( + "Both nReps and distance_threshold provided; using distance_threshold and ignoring nReps." + ) + nReps = None + + # Check if nReps is valid for the number of filtered graphs + if nReps and nReps > filtered_count: + raise ValueError( + f"Cannot create {nReps} clusters from {filtered_count} graphs. " + f"Set nReps to {filtered_count} or fewer, or relax your filter." 
+ ) + model = AgglomerativeClustering( metric="precomputed", linkage="average", compute_distances=True, - distance_threshold=None, n_clusters=nReps, + distance_threshold=distance_threshold, + ) + mode_desc = ( + f"n_clusters={nReps}" + if nReps is not None + else f"distance_threshold={distance_threshold}" ) + logger.info( + f"Clustering {filtered_count} graph(s) with AgglomerativeClustering ({mode_desc})…" + ) + t2 = time.perf_counter() model.fit(self.distances) + t3 = time.perf_counter() labels = model.labels_ - subgroups = {} + subgroups = {label: self.keys[labels == label] for label in set(labels)} - for label in labels: - mask = np.where(labels == label, True, False) - subkeys = self.keys[mask] - subgroups[label] = subkeys + # Log cluster size distribution + cluster_sizes = { + int(lbl): int(len(members)) for lbl, members in subgroups.items() + } + size_list = sorted(cluster_sizes.values(), reverse=True) + logger.info( + f"Formed {len(subgroups)} cluster(s) in {t3 - t2:.2f}s | sizes: {size_list}" + ) - for key in subgroups.keys(): - subgroup = subgroups[key] + self.selection = {} + for label, subgroup in subgroups.items(): selected_star = selector_fn(subgroup) - self.selection[key] = { + self.selection[label] = { "star": selected_star, "cluster_size": len(subgroup), } + # Keep detailed selection at DEBUG to avoid log spam + try: + star_name = os.path.basename(str(selected_star)) + except Exception: + star_name = str(selected_star) + logger.debug( + "Cluster %s: selected representative '%s' from %d member(s)", + str(label), + star_name, + len(subgroup), + ) + total_time = (t1 - t0) + (t3 - t2) + logger.info( + f"Galaxy Collapse complete: {len(self.selection)} representative model(s) selected " + f"({metric}, {mode_desc}). Total compute time ~{total_time:.2f}s" + ) + logger.info( + "Access results: this Galaxy's 'selection' maps cluster -> {'star','cluster_size'}. " + "If using a Thema instance, check its 'selected_model_files' for the chosen file paths." 
+ ) return self.selection def get_galaxy_coordinates(self) -> np.ndarray: @@ -513,22 +797,24 @@ def writeParams_toYaml(self, YAML_PATH=None): None """ - if YAML_PATH is None and self.YAML_PATH is not None: - YAML_PATH = self.YAML_PATH - - if YAML_PATH is None and self.YAML_PATH is None: - raise ValueError("Please provide a valid filepath to YAML") + # Resolve yaml path to a non-None string for type checking + if YAML_PATH is None: + if self.YAML_PATH is None: + raise ValueError("Please provide a valid filepath to YAML") + yaml_path = cast(str, self.YAML_PATH) + else: + yaml_path = str(YAML_PATH) - if not os.path.isfile(YAML_PATH): + if not os.path.isfile(yaml_path): raise TypeError("File path does not point to a YAML file") - with open(YAML_PATH, "r") as f: + with open(yaml_path, "r") as f: params = OmegaConf.load(f) params.Galaxy = self.getParams()["params"] params.Galaxy.stars = list(self.getParams()["params"].keys()) - with open(YAML_PATH, "w") as f: + with open(yaml_path, "w") as f: OmegaConf.save(params, f) print("YAML file successfully updated") diff --git a/thema/multiverse/universe/geodesics.py b/thema/multiverse/universe/geodesics.py index f667a4b6..ec3c5919 100644 --- a/thema/multiverse/universe/geodesics.py +++ b/thema/multiverse/universe/geodesics.py @@ -1,61 +1,73 @@ # File: multiverse/universe/geodesics.py -# Lasted Updated: 07/29/25 -# Updated By: JW +# Lasted Updated: 10/21/25 +# Updated By: SG import os import pickle from typing import Callable -import networkx as nx import numpy as np +import networkx as nx from scott import Comparator from .utils.starFilters import nofilterfunction def stellar_curvature_distance( - files, + files: str | list, filterfunction: Callable | None = None, curvature="forman_curvature", vectorization="landscape", ): """ - Compute a pairwise distance matrix between graphs based on curvature filtrations. + Compute a pairwise distance matrix between graphs using curvature filtrations. Parameters ---------- - files : str - A path pointing to the directory containing starGraphs. + files : str or list[str] + Either a path to a directory containing starGraph files or a list of individual file paths. filterfunction : Callable, optional - A customizable filter function for pulling a subset of cosmic graphs. - Default is None. - kernel : str, optional - The kernel to be used for computing pairwise distances. - Default is "shortest_path". + A custom filter function to select a subset of cosmic graphs. Defaults to None. + curvature : str, optional + The curvature measure to use. Defaults to "forman_curvature". + + Supported values (increasing in complexity and computational intensity): + - "forman_curvature" : + A combinatorial measure based purely on local graph structure. + Fast to compute and suitable for large graphs or exploratory analysis. + - "balanced_forman_curvature" : + A refinement of Forman curvature that balances edge contributions, + improving sensitivity to degree heterogeneity while remaining efficient. + - "resistance_curvature" : + Derived from effective resistance distances between nodes. + Captures global connectivity patterns but is more computationally demanding. + - "ollivier_ricci_curvature" : + A transport-based curvature measure that reflects the geometry of + probabilistic mass movement between node neighborhoods. Provides the + most geometric insight but is the slowest to compute. 
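Because _load_starGraphs (below) now accepts either a directory or an explicit list of pickled stars, the metric can be exercised directly on a few saved files; a sketch with illustrative paths:

>>> from thema.multiverse.universe import geodesics
>>> keys, D = geodesics.stellar_curvature_distance(
...     files=["models/jmapStar_0.pkl", "models/jmapStar_1.pkl"],
...     curvature="forman_curvature",
... )
>>> D.shape
(2, 2)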
+ + For further details, see: + https://github.com/aidos-lab/curvature-filtrations/blob/main/notebooks/bagpipeline.ipynb + + vectorization : str, optional + Vectorization method for computing distances. Defaults to "landscape". Returns ------- - keys : np.array - A list of the keys for the models being compared. + keys : np.ndarray + Array of keys identifying the models being compared. distance_matrix : np.ndarray - A pairwise distance matrix between the persistence landscapes of the starGraphs. + Pairwise distance matrix between the persistence landscapes of the starGraphs. """ - starGraphs = _load_starGraphs(files, filterfunction) + + starGraphs = _load_starGraphs(files, graph_filter=filterfunction) keys = list(starGraphs.keys()) - # Convert starGraphs values to a list for indexed access starGraph_list = list(starGraphs.values()) - # Extract the actual networkx graphs graphs = [sg.graph for sg in starGraph_list] + mapped_graphs, _ = _map_string_nodes_to_integers(graphs) + C = Comparator(measure=curvature, weight="weight") - # Map string node IDs to integers for GUDHI compatibility - mapped_graphs, node_mapping = _map_string_nodes_to_integers(graphs) - - # Create a Curvature Comparator - C = Comparator( - measure=curvature, - weight="weight", - ) n = len(mapped_graphs) distance_matrix = np.zeros((n, n)) for i in range(n): @@ -71,48 +83,50 @@ def stellar_curvature_distance( return np.array(keys), distance_matrix -def _load_starGraphs(dir: str, graph_filter: Callable | None = None) -> dict: +def _load_starGraphs(dir: str | list, graph_filter: Callable | None = None) -> dict: """ - Load starGraphs in a given directory. This function only - returns diagrams for starGraphs that satisfy the constraint - given by `graph_filter`. + Load starGraphs from a directory or a list of pickle files. + Only returns starGraphs that satisfy the `graph_filter`. Parameters ---------- - dir : str - The directory containing the graphs, - from which diagrams can be extracted. + dir : str or list + Directory containing .pkl graphs, or a list of .pkl file paths. graph_filter : Callable, optional - Default to None (ie no filter). Only select graph object based on filter - function criteria (returns 1 to include and 0 to exclude) + Function that returns True for graphs to include. Defaults to nofilterfunction. Returns ------- dict - A dictionary mapping the graph object file paths - to the corresponding persistence diagram object. + Mapping of file path to starGraph object. """ - - assert os.path.isdir(dir), "Invalid graph Directory" - assert len(os.listdir(dir)) > 0, "Graph directory appears to be empty!" - if graph_filter is None: graph_filter = nofilterfunction + # Handle list vs directory + if isinstance(dir, list): + files = [str(f) for f in dir] + else: + assert os.path.isdir(dir), "Invalid graph Directory" + assert len(os.listdir(dir)) > 0, "Graph directory appears to be empty!" + files = [os.path.join(dir, f) for f in os.listdir(dir) if f.endswith(".pkl")] + + if not files: + raise ValueError("No .pkl files found to load.") + starGraphs = {} - for file in os.listdir(dir): - if file.endswith(".pkl"): - graph_file = os.path.join(dir, file) - with open(graph_file, "rb") as f: - graph_object = pickle.load(f) - - if graph_filter(graph_object): - if graph_object.starGraph is not None: - starGraphs[graph_file] = graph_object.starGraph - assert ( - len(starGraphs) > 0 - ), "You haven't produced any valid starGraphs. \ - Your filter function may be too stringent." 
+ for graph_file in files: + with open(graph_file, "rb") as f: + graph_object = pickle.load(f) + + if graph_filter(graph_object): + if graph_object.starGraph is not None: + starGraphs[graph_file] = graph_object.starGraph + + if not starGraphs: + raise ValueError( + "No valid starGraphs produced. Your filter function may be too stringent." + ) return starGraphs @@ -136,23 +150,18 @@ def _map_string_nodes_to_integers(graphs): (mapped_graphs, node_mapping) where mapped_graphs have integer node IDs and node_mapping is the string->int mapping dict """ - # Collect all unique nodes across all graphs all_nodes = set() for graph in graphs: all_nodes.update(graph.nodes()) - # Create consistent mapping from string nodes to integers node_mapping = {node: i for i, node in enumerate(sorted(all_nodes))} - # Map all graphs to use integer node IDs mapped_graphs = [] for graph in graphs: - # Only remap if we have non-integer nodes if any(not isinstance(node, int) for node in graph.nodes()): mapped_graph = nx.relabel_nodes(graph, node_mapping) mapped_graphs.append(mapped_graph) else: - # Graph already has integer nodes mapped_graphs.append(graph.copy()) return mapped_graphs, node_mapping diff --git a/thema/multiverse/universe/star.py b/thema/multiverse/universe/star.py index d725a7f6..12b57e69 100644 --- a/thema/multiverse/universe/star.py +++ b/thema/multiverse/universe/star.py @@ -2,12 +2,17 @@ # Last Update: 05/15/24 # Updated by: JW +import logging import os import pickle from abc import abstractmethod from ...core import Core +# Configure module logger +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + class Star(Core): """ @@ -86,5 +91,6 @@ def save(self, file_path, force=False): return True else: return False - except Exception: + except Exception as e: + logger.error(f"Failed to save Star to {file_path}: {str(e)}") return False diff --git a/thema/multiverse/universe/stars/jmapStar.py b/thema/multiverse/universe/stars/jmapStar.py index a5d65389..e0970323 100644 --- a/thema/multiverse/universe/stars/jmapStar.py +++ b/thema/multiverse/universe/stars/jmapStar.py @@ -2,8 +2,8 @@ # Last Update: 05/15/24 # Updated by: JW - import itertools +import logging from collections import defaultdict import networkx as nx @@ -14,6 +14,10 @@ from ..star import Star from ..utils.starGraph import starGraph +# Configure module logger +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + def initialize(): """ @@ -128,6 +132,14 @@ def __init__( self.minIntersection = minIntersection self.clusterer = get_clusterer(clusterer) self.mapper = KeplerMapper() + + # Store parameters for potential debugging + self._params = { + "nCubes": nCubes, + "percOverlap": percOverlap, + "minIntersection": minIntersection, + "clusterer": clusterer + } def fit(self): """Computes a kmapper complex based on the configuration parameters and @@ -151,6 +163,16 @@ def fit(self): cover=Cover(self.nCubes, self.percOverlap), clusterer=self.clusterer, ) + + if not self.complex or "nodes" not in self.complex: + logger.debug( + f"KeplerMapper produced empty complex - params: {self._params}, " + f"projection shape: {self.projection.shape}" + ) + self.complex = None + self.starGraph = None + return + self.nodes = convert_keys_to_alphabet(self.complex["nodes"]) graph = nx.Graph() @@ -160,7 +182,12 @@ def fit(self): edges = nerve.compute(self.nodes) if len(edges) == 0: - self.starGraph = None + # Log when we get empty graphs - this is important for debugging + logger.debug( + f"No edges found in graph 
- params: {self._params}, " + f"nodes: {len(self.nodes)}, projection shape: {self.projection.shape}" + ) + self.starGraph = starGraph(graph) # Create empty graph instead of None else: graph.add_nodes_from(self.nodes) nx.set_node_attributes(graph, self.nodes, "membership") @@ -169,10 +196,14 @@ def fit(self): graph.add_weighted_edges_from(edges) else: graph.add_edges_from(edges) + + self.starGraph = starGraph(graph) - self.starGraph = starGraph(graph) - - except: + except Exception as e: + logger.error( + f"jmapStar.fit() failed with params: {self._params}, " + f"projection shape: {self.projection.shape}, error: {str(e)}" + ) self.complex = None self.starGraph = None diff --git a/thema/multiverse/universe/utils/starFilters.py b/thema/multiverse/universe/utils/starFilters.py index 09b2af47..c0bf3fdd 100644 --- a/thema/multiverse/universe/utils/starFilters.py +++ b/thema/multiverse/universe/utils/starFilters.py @@ -1,2 +1,135 @@ -def nofilterfunction(graphobject): +from typing import Callable +import networkx as nx + + +def nofilterfunction(graphobject) -> int: + """Default filter that accepts all graph objects.""" return 1 + + +def component_count_filter(target_components: int) -> Callable: + """Filter for graphs with specific number of connected components. + + Args: + target_components: Desired number of connected components + + Returns: + Filter function that returns 1 for matching graphs, 0 otherwise + + Example: + >>> filter_func = component_count_filter(4) + >>> galaxy.collapse(filter_fn=filter_func) + """ + + def _filter(graphobject) -> int: + if graphobject.starGraph is None: + return 0 + return ( + 1 + if nx.number_connected_components(graphobject.starGraph.graph) + == target_components + else 0 + ) + + return _filter + + +def component_count_range_filter( + min_components: int, max_components: int +) -> Callable: + """Filter for graphs within component count range. + + Args: + min_components: Minimum number of components (inclusive) + max_components: Maximum number of components (inclusive) + + Returns: + Filter function that returns 1 for graphs in range, 0 otherwise + """ + + def _filter(graphobject) -> int: + if graphobject.starGraph is None: + return 0 + n_components = nx.number_connected_components( + graphobject.starGraph.graph + ) + return 1 if min_components <= n_components <= max_components else 0 + + return _filter + + +def minimum_nodes_filter(min_nodes: int) -> Callable: + """Filter for graphs with minimum number of nodes. + + Args: + min_nodes: Minimum number of nodes required + + Returns: + Filter function that returns 1 for graphs meeting criteria, 0 otherwise + """ + + def _filter(graphobject) -> int: + if graphobject.starGraph is None: + return 0 + return ( + 1 + if graphobject.starGraph.graph.number_of_nodes() >= min_nodes + else 0 + ) + + return _filter + + +def minimum_edges_filter(min_edges: int) -> Callable: + """Filter for graphs with minimum number of edges. + + Args: + min_edges: Minimum number of edges required + + Returns: + Filter function that returns 1 for graphs meeting criteria, 0 otherwise + """ + + def _filter(graphobject) -> int: + if graphobject.starGraph is None: + return 0 + return ( + 1 + if graphobject.starGraph.graph.number_of_edges() >= min_edges + else 0 + ) + + return _filter + + +def minimum_unique_items_filter(min_unique_items: int) -> Callable: + """Filter for graphs with minimum number of unique items across all nodes. 
+ + This filter counts the total number of unique data points present + across all nodes in the Mapper graph, ensuring no double-counting + of items that appear in multiple nodes. + + Args: + min_unique_items: Minimum number of unique items required + + Returns: + Filter function that returns 1 for graphs meeting criteria, 0 otherwise + + Example: + >>> filter_func = minimum_unique_items_filter(100) + >>> galaxy.collapse(filter_fn=filter_func) + """ + + def _filter(graphobject) -> int: + if graphobject.starGraph is None: + return 0 + + # Collect all unique items from node membership lists + unique_items = set() + for node in graphobject.starGraph.graph.nodes(): + membership = graphobject.starGraph.graph.nodes[node]["membership"] + unique_items.update(membership) + + return 1 if len(unique_items) >= min_unique_items else 0 + + return _filter diff --git a/thema/thema.py b/thema/thema.py index 408b1adb..5426c530 100644 --- a/thema/thema.py +++ b/thema/thema.py @@ -7,6 +7,7 @@ """ import os +import networkx as nx from omegaconf import OmegaConf from thema.multiverse import Planet, Oort, Galaxy @@ -40,6 +41,12 @@ class Thema: -------- >>> thema = Thema('params.yaml') >>> thema.genesis() # Run the full pipeline + >>> + >>> # Access representative stars after genesis + >>> selection = thema.galaxy.selection + >>> print(f"Selected {len(selection)} representative stars") + >>> for cluster_id, info in selection.items(): + ... print(f"Cluster {cluster_id}: {info['star']} ({info['cluster_size']} stars)") """ def __init__(self, YAML_PATH): @@ -66,6 +73,7 @@ def __init__(self, YAML_PATH): self.clean_files = None self.projection_files = None self.model_files = None + self.selected_model_files = None def genesis(self): """ @@ -86,6 +94,10 @@ def genesis(self): -------- >>> thema = Thema('params.yaml') >>> thema.genesis() + >>> + >>> # Representative stars are automatically selected and stored in galaxy.selection + >>> selected_files = [info['star'] for info in thema.galaxy.selection.values()] + >>> print(f"Representative files: {selected_files}") """ self.spaghettify_innerSystem() self.innerSystem_genesis() @@ -289,6 +301,10 @@ def galaxy_genesis(self): >>> thema = Thema('params.yaml') >>> thema.spaghettify_galaxy() # First clean the directory >>> thema.galaxy_genesis() # Process the data + >>> + >>> # Representative stars are stored in galaxy.selection after collapse + >>> representative_files = [info['star'] for info in thema.galaxy.selection.values()] + >>> print(f"Found {len(representative_files)} representative stars") """ model_outdir = os.path.join( self.params.outDir, self.params.runName + "/models/" @@ -304,6 +320,10 @@ def galaxy_genesis(self): self.model_files = [ model_outdir + file for file in os.listdir(model_outdir) ] + self.galaxy.collapse() + self.selected_model_files = [ + str(x["star"]) for x in self.galaxy.selection.values() + ] def spaghettify_galaxy(self): """ diff --git a/thema/utils.py b/thema/utils.py index 60838a06..aafd0b32 100644 --- a/thema/utils.py +++ b/thema/utils.py @@ -1,5 +1,5 @@ # File: thema/utils.py -# Last Update: 05/15/24 +# Last Update: 10/16/25 # Updated by: JW import os @@ -8,7 +8,17 @@ from concurrent.futures import ProcessPoolExecutor, as_completed import pandas as pd -from tqdm import tqdm +import logging + +try: + from IPython import get_ipython + + if get_ipython() is not None and "IPKernelApp" in get_ipython().config: + from tqdm.notebook import tqdm + else: + from tqdm import tqdm +except (ImportError, AttributeError): + from tqdm import tqdm def 
function_scheduler( @@ -60,7 +70,7 @@ def function_scheduler( total=len(functions), desc="Progress", unit="function", - dynamic_ncols=True, + # dynamic_ncols=True, ) outcomes = [] @@ -183,3 +193,64 @@ def sanitize(value): filename = "_".join(parts) + ".pkl" return filename + + +def get_current_logging_config(): + """ + Get current logging configuration state. + Used by multiprocessing functions to replicate logging config in child processes. + + Returns + ------- + dict or None + Dictionary with logging config if enabled, None if disabled. + """ + thema_logger = logging.getLogger("thema") + + if thema_logger.handlers and not all( + isinstance(h, logging.NullHandler) for h in thema_logger.handlers + ): + return {"level": thema_logger.getEffectiveLevel(), "enabled": True} + else: + return None + + +def configure_child_process_logging(config): + """ + Configure logging in child processes. + + Parameters + ---------- + config : dict or None + Logging configuration from parent process. + """ + if config is None: + return + + try: + thema_logger = logging.getLogger("thema") + + if not thema_logger.handlers or all( + isinstance(h, logging.NullHandler) for h in thema_logger.handlers + ): + + for handler in thema_logger.handlers[:]: + if isinstance(handler, logging.NullHandler): + thema_logger.removeHandler(handler) + + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + thema_logger.addHandler(handler) + thema_logger.setLevel(config["level"]) + thema_logger.propagate = False + + # Reset child module loggers + for name in list(logging.Logger.manager.loggerDict.keys()): + if name.startswith("thema."): + child_module_logger = logging.getLogger(name) + child_module_logger.setLevel(logging.NOTSET) + except (AttributeError, KeyError, ValueError) as e: + pass
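The two helpers above are the parent and child halves of the multiprocessing logging hand-off used by Planet, Oort, and Galaxy in this change. A minimal sketch of the round trip (level 20 is logging.INFO):

>>> import thema
>>> from thema.utils import get_current_logging_config, configure_child_process_logging
>>> thema.enable_logging("INFO")
Thema logging enabled at INFO level
>>> cfg = get_current_logging_config()
>>> cfg
{'level': 20, 'enabled': True}
>>> # first thing each worker does in its own process:
>>> configure_child_process_logging(cfg)    # re-attaches a StreamHandler at the parent's level
>>> configure_child_process_logging(None)   # no-op when the parent had logging disabled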