diff --git a/.github/workflows/test_package.yml b/.github/workflows/test_package.yml index d04345d..b2ab4bb 100644 --- a/.github/workflows/test_package.yml +++ b/.github/workflows/test_package.yml @@ -22,10 +22,14 @@ jobs: uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} + - name: Set up R + uses: r-lib/actions/setup-r@v2 + with: + r-version: '4.3.2' - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install -r requirements_dev.txt - name: Test with pytest run: | - python -m pytest + python -m pytest --cov=causalAssembly tests/ diff --git a/.lintr b/.lintr new file mode 100644 index 0000000..8c59a64 --- /dev/null +++ b/.lintr @@ -0,0 +1,5 @@ +linters: linters_with_defaults( + line_length_linter = line_length_linter(120L), + object_name_linter = NULL, + indentation_linter = NULL + ) diff --git a/Makefile b/Makefile index 612ffb9..d531305 100644 --- a/Makefile +++ b/Makefile @@ -1,28 +1,50 @@ #Makefile for causalAssembly + MAKEFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) CURRENT_DIR := $(notdir $(patsubst %/,%,$(dir $(MAKEFILE_PATH)))) VENV_NAME := venv_${CURRENT_DIR} PYTHON=${VENV_NAME}/bin/python +ALLOWED_LICENSES := "$(shell tr -s '\n' ';' < allowed_licenses.txt)" +ALLOWED_PACKAGES := $(shell tr -s '\n' ' ' < allowed_packages.txt) + sync-venv: : # Create or update default virtual environment to latest pinned : # dependencies test -d $(VENV_NAME) || \ python3.10 -m virtualenv $(VENV_NAME); \ - ${PYTHON} -m pip install -U pip; \ - ${PYTHON} -m pip install pip-tools - . $(VENV_NAME)/bin/activate && pip-sync requirements_dev.txt - #. $(VENV_NAME)/bin/activate && pip install --no-deps -e . + ${PYTHON} -m pip install -U uv; \ + . $(VENV_NAME)/bin/activate && uv pip sync requirements_dev.txt + . $(VENV_NAME)/bin/activate && uv pip install --no-deps -e . 
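The workflow change above provisions R 4.3.2 next to each Python version because causalAssembly reaches into R (via `rpy2`, see `drf_fitting.py` below) for distributional-random-forest fitting. A minimal smoke test, assuming `rpy2` is pinned in `requirements_dev.txt` (the test name and file placement are illustrative, not part of this patch):

```python
# tests/test_r_toolchain.py -- hypothetical file; verifies that the R
# installation provided by r-lib/actions/setup-r is visible to rpy2.
import rpy2.robjects as ro


def test_r_is_reachable():
    # R.version.string is a base-R value such as "R version 4.3.2 (...)".
    version = ro.r("R.version.string")[0]
    assert version.startswith("R version")
```

Failing fast here keeps DRF-related test errors from masquerading as R-setup problems.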
requirements: : # Update requirements files if only a new library is added : # Assumes virtual environment with uv installed is activated - pip-compile --extra dev -o requirements_dev.txt pyproject.toml --annotation-style line --no-emit-index-url --no-emit-trusted-host --allow-unsafe --resolver=backtracking + uv pip compile pyproject.toml --output-file=requirements.txt --annotation-style=line + uv pip compile pyproject.toml --extra=dev --output-file=requirements_dev.txt --annotation-style=line update-requirements: : # Update requirements files if dependencies should be updated : # Assumes virtual environment with uv installed is activated - pip-compile --extra dev -o requirements_dev.txt pyproject.toml --annotation-style line --no-emit-index-url --no-emit-trusted-host --allow-unsafe --resolver=backtracking --upgrade + uv pip compile pyproject.toml --output-file=requirements.txt --annotation-style=line --upgrade + uv pip compile pyproject.toml --extra=dev --output-file=requirements_dev.txt --annotation-style=line --upgrade + + +precommit: + : # Run pre-commit on all files locally (this runs only the pre-commit and not the pre-push hooks) + pre-commit install + pre-commit run --all-files + +test: + : # Run pytest + python -m pytest + +licenses: + : # Create file with used licenses and check that these are permissive + pip-licenses --order=license --ignore-packages ${ALLOWED_PACKAGES} --allow-only=${ALLOWED_LICENSES} > licenses.txt + +clean-branches: + git branch --merged | grep -v '\*\|main\|develop' | xargs -n 1 -r git branch -d diff --git a/README.md b/README.md index 3585308..30e976f 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ This repo provides details regarding $\texttt{causalAssembly}$, a causal discovery benchmark data tool based on complex production data. Theoretical details and information regarding construction are presented in the [paper](https://arxiv.org/abs/2306.10816): - Göbler, K., Windisch, T., Pychynski, T., Sonntag, S., Roth, M., & Drton, M. causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery, to appear in Proceedings of the 3rd Conference on Causal Learning and Reasoning (CLeaR), 2024, + Göbler, K., Windisch, T., Pychynski, T., Sonntag, S., Roth, M., & Drton, M.
causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery, to appear in Proceedings of the 3rd Conference on Causal Learning and Reasoning (CLeaR), 2024, ## Authors * [Konstantin Goebler](mailto:konstantin.goebler@de.bosch.com) * [Steffen Sonntag](mailto:steffen.sonntag@de.bosch.com) diff --git a/VERSION b/VERSION index 781dcb0..26aaba0 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.1.3 +1.2.0 diff --git a/allowed_licenses.txt b/allowed_licenses.txt new file mode 100644 index 0000000..dc8290e --- /dev/null +++ b/allowed_licenses.txt @@ -0,0 +1,24 @@ +Apache-2.0 +Apache 2.0 +Apache License 2.0 +APACHE SOFTWARE LICENSE +BOSCH-INTERNAL +BSD 3-Clause +BSD License +BSD +3-Clause BSD License +BSD (3-Clause) +ISC +ISC License (ISCL) +Other/Proprietary License +MIT +MIT License +CMU License (MIT-CMU) +GPL-2.0-or-later +GNU General Public License v2 or later (GPLv2+) +Python Software Foundation License +The Unlicense (Unlicense) +Historical Permission Notice and Disclaimer (HPND) +MPL-2.0 +NVIDIA Proprietary Software +UNKNOWN diff --git a/allowed_packages.txt b/allowed_packages.txt new file mode 100644 index 0000000..f7aa7de --- /dev/null +++ b/allowed_packages.txt @@ -0,0 +1,8 @@ +causalAssembly +rpy2 +chardet +certifi +dbx +mdpcommonlib +pathspec +text-unidecode diff --git a/benchmarks/run.py b/benchmarks/run.py index eccbc36..cce3de5 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify diff --git a/benchmarks/utils.py b/benchmarks/utils.py index adbfc43..f68b1aa 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -12,11 +13,13 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ + import logging from collections import defaultdict from concurrent.futures import ProcessPoolExecutor from functools import partial from itertools import repeat +from typing import Any import lingam import networkx as nx @@ -40,12 +43,15 @@ class BenchMarker: - """Class to run causal discovery benchmarks. One instance can be + """Class to run causal discovery benchmarks. + + One instance can be the basis for several benchmark runs based on the same settings with different data generator objects. 
""" def __init__(self): + """Initializes the BenchMarker class.""" self.algorithms: dict = {"snr": BenchMarker._fit_snr} self.collect_results: dict = {} self.num_runs: int @@ -54,37 +60,37 @@ def __init__(self): self.n_select: int def include_pc(self): - """Includes the PC-stable algorithm from the `causal-learn` package""" + """Includes the PC-stable algorithm from the `causal-learn` package.""" logger.info("PC algorithm added to benchmark routines.") self.algorithms["pc"] = BenchMarker._fit_pc def include_ges(self): - """Includes GES from the `causal-learn` package""" + """Includes GES from the `causal-learn` package.""" logger.info("GES algorithm added to benchmark routines.") self.algorithms["ges"] = BenchMarker._fit_ges def include_notears(self): - """Includes the NOTEARS algorithm from the `gcastle` package""" + """Includes the NOTEARS algorithm from the `gcastle` package.""" logger.info("NOTEARS added to benchmark routines.") self.algorithms["notears"] = BenchMarker._fit_notears def include_grandag(self): - """Includes the Gran-DAG algorithm from the `gcastle` package""" + """Includes the Gran-DAG algorithm from the `gcastle` package.""" logger.info("Gran-DAG added to benchmark routines.") self.algorithms["grandag"] = BenchMarker._fit_grandag def include_score(self): - """Includes the SCORE algorithm from the `dodiscovery` package""" + """Includes the SCORE algorithm from the `dodiscovery` package.""" logger.info("SCORE algorithm added to benchmark routines.") self.algorithms["score"] = BenchMarker._fit_score def include_das(self): - """Includes the DAS algorithm from the `dodiscovery` package""" + """Includes the DAS algorithm from the `dodiscovery` package.""" logger.info("DAS algorithm added to benchmark routines.") self.algorithms["das"] = BenchMarker._fit_das def include_lingam(self): - """Includes the DirectLiNGAM algorithm from the `lingam` package""" + """Includes the DirectLiNGAM algorithm from the `lingam` package.""" logger.info("Direct LiNGAM added to benchmark routines.") self.algorithms["lingam"] = BenchMarker._fit_lingam @@ -112,7 +118,7 @@ def _causallearn2amat(causal_learn_graph: np.ndarray) -> np.ndarray: if causal_learn_graph[row, col] == -1 and causal_learn_graph[col, row] == -1: amat[row, col] = amat[col, row] = 1 if causal_learn_graph[row, col] == 1 and causal_learn_graph[col, row] == 1: - logger.warning(f"ambiguity found in {(row,col)}. I'll make it bidirected") + logger.warning(f"ambiguity found in {(row, col)}. I'll make it bidirected") amat[row, col] = amat[col, row] = 1 return amat @@ -131,9 +137,12 @@ def _fit_ges(data: pd.DataFrame) -> PDAG: # Beware of the Simulated DAG! Causal Discovery Benchmarks May Be Easy To Game. @staticmethod def _fit_snr(data: pd.DataFrame) -> pd.DataFrame: - """Take n x d data, order nodes by marginal variance and + """SNR algo. + + Take n x d data, order nodes by marginal variance and regresses each node onto those with lower variance, using - edge coefficients as structure estimates.""" + edge coefficients as structure estimates. + """ X = data.to_numpy() LR = LinearRegression() LL = LassoLarsIC(criterion="bic") @@ -158,10 +167,13 @@ def _fit_snr(data: pd.DataFrame) -> pd.DataFrame: # Beware of the Simulated DAG! Causal Discovery Benchmarks May Be Easy To Game. @staticmethod def varsortability(data: pd.DataFrame, ground_truth: pd.DataFrame, tol=1e-9): - """Takes n x d data and a d x d adjaceny matrix, + """Varsortability algo. 
+ + Takes n x d data and a d x d adjacency matrix, where the i,j-th entry corresponds to the edge weight for i->j, and returns a value indicating how well the variance order - reflects the causal order. + reflects the causal order. + """ X = data.to_numpy() W = ground_truth.to_numpy() E = W != 0 @@ -226,7 +238,7 @@ def _fit_das(data: pd.DataFrame) -> pd.DataFrame: def run_benchmark( self, runs: int, - prod_obj: ProductionLineGraph | ProcessCell, + prod_obj: ProductionLineGraph, n_select: int = 500, harmonize_via: str | None = "cpdag_transform", size_threshold: int = 50, @@ -252,11 +264,24 @@ `"best_dag_shd"` is selected, all DAGs in the implied MEC will be enumerated, the SHD calculated and the lowest (best) candidate DAG chosen. Defaults to "cpdag_transform". + size_threshold (int, optional): Size threshold for the estimated graphs. Defaults to 50. parallelize (bool, optional): Whether to run on parallel processes. Defaults to False. n_workers (int, optional): If `parallelize = True` you need to assign the number of workers to parallelize over. Defaults to 4. seed_sequence (int, optional): If `parallelize = True` you may choose the seed sequence handed down to every parallel process. Defaults to 1234. + chunksize (int | None): If `parallelize = True` you may choose the + chunksize for the parallelization. If `None`, it will be set to + `runs / n_workers` or 1, whichever is larger. Defaults to None. + external_dfs (list[pd.DataFrame | np.ndarray] | None, optional): + If you want to use external dataframes for the benchmark runs, you can pass a list + of dataframes or numpy arrays here. The length of the list must match the number of + runs. If `None`, the `prod_obj` will be used to sample data. + Defaults to None. + between_and_within_results (bool, optional): If `True`, the benchmark will also + return the within and between metrics for the `prod_obj` if it is a + `ProductionLineGraph`. If `False`, only the metrics for the overall graph will be + returned. Defaults to False. """ self.num_runs = runs self.prod_object = prod_obj @@ -359,9 +384,9 @@ def run_benchmark( @staticmethod def single_run( - new_seed: np.random.BitGenerator | None, - prod_obj: ProductionLineGraph | ProcessCell, - algorithms: list[str], + new_seed: int | None, + prod_obj: ProductionLineGraph, + algorithms: dict[str, Any], child_seed: None | np.random.SeedSequence = None, harmonize_via: None | str = "cpdag_transform", n_select: int = 500, @@ -369,29 +394,37 @@ def single_run( external_df: pd.DataFrame | np.ndarray | None = None, between_and_within_results: bool = False, ): - """Single benchmark run + """Single benchmark run. Args: - new_seed (np.random.BitGenerator | None): seed - prod_obj (ProductionLineGraph | ProcessCell): object of interest - algorithms (list[str]): cd algs - child_seed (None | np.random.SeedSequence, optional): Defaults to None. - harmonize_via (None | str, optional): Defaults to "cpdag_transform". - n_select (int, optional): Defaults to 500. - size_threshold (int, optional): Defaults to 50. + new_seed (int | None): Index of the child seed to use for this run. + prod_obj (ProductionLineGraph): Object of interest. + algorithms (dict[str, Any]): Mapping of algorithm names to fitting routines. + child_seed (None | np.random.SeedSequence, optional): Seed sequence to spawn per-run seeds from. Defaults to None. + harmonize_via (None | str, optional): How estimate and ground truth are harmonized before scoring. Defaults to "cpdag_transform". + n_select (int, optional): Number of samples drawn per run. Defaults to 500. + size_threshold (int, optional): Size threshold for the estimated graphs. Defaults to 50. + external_df (pd.DataFrame | np.ndarray | None, optional): + External data to use for this run instead of sampling from `prod_obj`. Defaults to None.
+ between_and_within_results (bool, optional): Whether to additionally report within- and between-cell metrics. Defaults to False. Raises: - AssertionError - AssertionError - TypeError + AssertionError: If `external_df` does not fit the graph at hand. + AssertionError: If something went wrong in the best DAG selection. + TypeError: If ground truth and result are not of the same type. Returns: - dict: results from that run. + dict: Results from that run. """ metrics = defaultdict(partial(defaultdict, list)) within_between_metrics = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) - if child_seed: - prod_obj.random_state = np.random.default_rng(seed=child_seed[new_seed]) + if child_seed is not None: + num_seeds = 10 + child_seeds = child_seed.spawn(num_seeds) + if new_seed is None: + new_seed = 0 # Or some default + prod_obj.random_state = np.random.default_rng(child_seeds[new_seed]) if external_df is not None: if isinstance(external_df, pd.DataFrame): @@ -428,7 +461,9 @@ def single_run( dag_metrics = DAGmetrics(truth=prod_obj.graph, est=dag) all_shds.append(dag_metrics._shd()) - absolute_distance_to_mean = np.abs(all_shds - np.mean(all_shds)) + absolute_distance_to_mean = np.abs( + np.array(all_shds) - np.mean(all_shds) + ) random_index_choice = np.random.choice( np.flatnonzero( absolute_distance_to_mean == np.min(absolute_distance_to_mean) @@ -444,18 +479,17 @@ def single_run( ground_truth, pd.DataFrame ): raise AssertionError("something went wrong in the best DAG selection") - else: - if type(ground_truth) is not type(result): - raise TypeError("ground truth and results need to have the same instance.") + elif type(ground_truth) is not type(result): + raise TypeError("ground truth and results need to have the same instance.") get_metrics = DAGmetrics(truth=ground_truth, est=result) - get_metrics.collect_metrics() + my_metrics = get_metrics.collect_metrics() if harmonize_via == "best_dag_shd": target_dag = prod_obj.graph result_dag = nx.from_pandas_adjacency(df=result, create_using=nx.DiGraph) - get_metrics.metrics["sid"] = int(SID(target=target_dag, pred=result_dag)) + my_metrics["sid"] = int(SID(target=target_dag, pred=result_dag)) if harmonize_via == "cpdag_transform": - get_metrics.metrics["shd"] = get_metrics._shd(count_anticausal_twice=False) + my_metrics["shd"] = get_metrics._shd(count_anticausal_twice=False) if between_and_within_results and harmonize_via == "best_dag_shd": if isinstance(prod_obj, ProcessCell): @@ -473,24 +507,24 @@ def single_run( get_within_metrics = DAGmetrics( truth=within_metrics, est=result_plg.within_adjacency ) - get_within_metrics.collect_metrics() + within_metrics = get_within_metrics.collect_metrics() get_between_metrics = DAGmetrics( truth=between_metrics, est=result_plg.between_adjacency ) - get_between_metrics.collect_metrics() + between_metrics = get_between_metrics.collect_metrics() # make dict - w_b_dict = {"within": get_within_metrics, "between": get_between_metrics} + w_b_dict = {"within": within_metrics, "between": between_metrics} for metric_name in ["precision", "recall"]: for which_one, dct in w_b_dict.items(): within_between_metrics[alg_name][metric_name][which_one].append( - dct.metrics[metric_name] + dct[metric_name] ) - for metric_name, _ in get_metrics.metrics.items(): - metrics[alg_name][metric_name].append(get_metrics.metrics[metric_name]) + for metric_name, _ in my_metrics.items(): + metrics[alg_name][metric_name].append(my_metrics[metric_name]) if between_and_within_results: return { @@ -503,7 +537,7 @@ def single_run( def causallearn2amat(causal_learn_graph: np.ndarray) -> np.ndarray: - """Causallearn object helper function + """Causallearn object helper
function. Args: causal_learn_graph (np.ndarray): causal-learn output graph @@ -522,8 +556,7 @@ def causallearn2amat(causal_learn_graph: np.ndarray) -> np.ndarray: def to_dag_amat(cpdag_amat: pd.DataFrame) -> pd.DataFrame: - """Turns PDAG into random member of the corresponding - Markov equivalence class. + """Turns PDAG into random member of the corresponding Markov equivalence class. Args: cpdag_amat (pd.DataFrame): PDAG representing the MEC @@ -531,7 +564,6 @@ def to_dag_amat(cpdag_amat: pd.DataFrame) -> pd.DataFrame: Returns: pd.DataFrame: DAG as member of MEC. """ - pdag = PDAG.from_pandas_adjacency(cpdag_amat) chosen_dag = pdag.to_dag() if not nx.is_directed_acyclic_graph(chosen_dag): diff --git a/causalAssembly/__init__.py b/causalAssembly/__init__.py index 03bfac7..bf2a89f 100644 --- a/causalAssembly/__init__.py +++ b/causalAssembly/__init__.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify diff --git a/causalAssembly/dag.py b/causalAssembly/dag.py index 513ec79..6be1ca9 100644 --- a/causalAssembly/dag.py +++ b/causalAssembly/dag.py @@ -1,4 +1,19 @@ -"""DAG class""" +"""DAG class. + +Copyright (c) 2023 Robert Bosch GmbH + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published +by the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . +""" + from __future__ import annotations import logging @@ -19,16 +34,22 @@ class DAG: - """ - General class for dealing with directed acyclic graph i.e. + """General class for dealing with directed acyclic graphs, i.e. + graphs that are directed and must not contain any cycles. """ def __init__( self, - nodes: list | None = None, - edges: list[tuple] | None = None, + nodes: list[str] | list[int] | set[int | str] | None = None, + edges: list[tuple[str | int, str | int]] | set[tuple[str | int, str | int]] | None = None, ): + """Initialize DAG. + + Args: + nodes (list | None, optional): Nodes of the DAG. Defaults to None. + edges (list[tuple] | None, optional): Directed edges of the DAG. Defaults to None. + """ if nodes is None: nodes = [] if edges is None: @@ -50,7 +71,6 @@ def _add_edge(self, i, j): self._edges.add((i, j)) # Check if graph is acyclic - # TODO: Make check really after each edge is added? if not self.is_acyclic(): raise ValueError( "The edge set you provided \ @@ -63,6 +83,11 @@ def _add_edge(self, i, j): @property def random_state(self): + """Random state. + + Returns: + np.random.Generator: The random number generator in use. + """ return self._random_state @random_state.setter @@ -72,7 +97,7 @@ def random_state(self, r: np.random.Generator): self._random_state = r def add_edge(self, edge: tuple[str, str]): - """Add edge to DAG + """Add edge to DAG. Args: edge (tuple[str, str]): Edge to add @@ -80,7 +105,7 @@ def add_edge(self, edge: tuple[str, str]): self._add_edge(*edge) def add_edges_from(self, edges: list[tuple[str, str]]): - """Add multiple edges to DAG + """Add multiple edges to DAG.
Args: edges (list[tuple[str, str]]): Edges to add @@ -129,8 +154,7 @@ def induced_subgraph(self, nodes: list[str]) -> DAG: return DAG(nodes=nodes, edges=edges) def is_adjacent(self, i: str, j: str) -> bool: - """Return True if the graph contains an directed - edge between i and j. + """Return True if the graph contains an directed edge between i and j. Args: i (str): node i. @@ -142,9 +166,7 @@ def is_adjacent(self, i: str, j: str) -> bool: return (j, i) in self.edges or (i, j) in self.edges def is_clique(self, potential_clique: set) -> bool: - """ - Check every pair of node X potential_clique is adjacent. - """ + """Check every pair of node X potential_clique is adjacent.""" return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2)) def is_acyclic(self) -> bool: @@ -167,7 +189,7 @@ def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> DAG: DAG """ assert pd_amat.shape[0] == pd_amat.shape[1] - nodes = pd_amat.columns + nodes = list(pd_amat.columns) all_connections = [] start, end = np.where(pd_amat != 0) @@ -182,7 +204,7 @@ def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> DAG: return DAG(nodes=nodes, edges=dir_edges) def remove_edge(self, i: str, j: str): - """Removes edge in question + """Removes edge in question. Args: i (str): tail @@ -199,10 +221,10 @@ def remove_edge(self, i: str, j: str): self._parents[j].discard(i) def remove_node(self, node): - """Remove a node from the graph""" + """Remove a node from the graph.""" self._nodes.remove(node) - self._edges = {(i, j) for i, j in self._edges if i != node and j != node} + self._edges = {(i, j) for i, j in self._edges if node not in (i, j)} for child in self._children[node]: self._parents[child].remove(node) @@ -215,8 +237,9 @@ def remove_node(self, node): @property def adjacency_matrix(self) -> pd.DataFrame: - """Returns adjacency matrix where the i,jth - entry being one indicates that there is an edge + """Returns adjacency matrix. + + The i,jth entry being one indicates that there is an edge from i to j. A zero indicates that there is no edge. Returns: @@ -232,7 +255,7 @@ def adjacency_matrix(self) -> pd.DataFrame: return amat def vstructs(self) -> set: - """Retrieve v-structures + """Retrieve v-structures. Returns: set: set of all v-structures @@ -246,7 +269,7 @@ def vstructs(self) -> set: return vstructures def copy(self): - """Return a copy of the graph""" + """Return a copy of the graph.""" return DAG(nodes=self._nodes, edges=self._edges) def show(self): @@ -287,8 +310,7 @@ def num_nodes(self) -> int: @property def num_edges(self) -> int: - """Number of directed edges - in current DAG. + """Number of directed edges in current DAG. Returns: int: Number of directed edges @@ -297,7 +319,7 @@ def num_edges(self) -> int: @property def sparsity(self) -> float: - """Sparsity of the graph + """Sparsity of the graph. Returns: float: in [0,1] @@ -307,8 +329,7 @@ def sparsity(self) -> float: @property def edges(self) -> list[tuple]: - """Gives all directed edges in - current DAG. + """Gives all directed edges in current DAG. Returns: list[tuple]: List of directed edges. @@ -318,6 +339,7 @@ def edges(self) -> list[tuple]: @property def causal_order(self) -> list[str]: """Returns the causal order of the current graph. + Note that this order is in general not unique. 
Returns: @@ -360,7 +382,7 @@ def from_nx(cls, nx_dag: nx.DiGraph) -> DAG: raise TypeError("DAG must be of type nx.DiGraph") return DAG(nodes=list(nx_dag.nodes), edges=list(nx_dag.edges)) - def save_drf(self, filename: str, location: str = None): + def save_drf(self, filename: str, location: str | Path | None = None): """Writes a drf dict to file. Please provide the .pkl suffix! Args: @@ -368,8 +390,7 @@ def save_drf(self, filename: str, location: str = None): location (str, optional): path to file in case it's not located in the current working directory. Defaults to None. """ - - if not location: + if location is None: location = Path().resolve() location_path = Path(location, filename) @@ -391,10 +412,15 @@ def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame: return _sample_from_drf(graph=self, size=size, smoothed=smoothed) def to_cpdag(self) -> PDAG: + """Conversion to CPDAG. + + Returns: + PDAG: _description_ + """ return dag2cpdag(dag=self.to_networkx()) @classmethod - def load_drf(cls, filename: str, location: str = None) -> dict: + def load_drf(cls, filename: str, location: str | Path | None = None) -> dict: """Loads a drf dict from a .pkl file into the workspace. Args: diff --git a/causalAssembly/dag_utils.py b/causalAssembly/dag_utils.py index 144ba9c..29c0954 100644 --- a/causalAssembly/dag_utils.py +++ b/causalAssembly/dag_utils.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -12,6 +13,7 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ + import itertools import logging @@ -43,7 +45,6 @@ def merge_dags( Returns: nx.DiGraph: merged DAG """ - for old_node_name, new_node_name in mapping.items(): if new_node_name not in target_dag.nodes(): raise ValueError(f"{new_node_name} does not exist in target_dag") @@ -74,7 +75,7 @@ def merge_dags( def merge_dags_via_edges( left_dag: nx.DiGraph, right_dag: nx.DiGraph, - edges: list[tuple] = None, + edges: list[tuple] | None = None, isolate_target_nodes: bool = False, ): """Merges two dags via a list of edges. @@ -119,7 +120,6 @@ def merge_dags_via_edges( merged_dag = nx.compose(left_dag, right_dag) - # TODO experimental merged_dag.add_edges_from(edges, **{"connector": True}) return merged_dag @@ -143,7 +143,6 @@ def tuples_from_cartesian_product(l1: list, l2: list) -> list[tuple]: [(0,'a'), (0,'b'), (0,'c'), (1,'a'), (1,'b'), (1,'c'), (2,'a'), (2,'b'), (2,'c')] """ - # TODO: This could take a long time for large graphs... return [ (tail, head) for tail, head in itertools.product( @@ -153,11 +152,15 @@ def tuples_from_cartesian_product(l1: list, l2: list) -> list[tuple]: ] -def _bootstrap_sample(rng: np.random.Generator, data: np.array, size: int = None) -> np.array: +def _bootstrap_sample( + rng: np.random.Generator, data: np.ndarray, size: int | None = None +) -> np.ndarray: """Generate bootstrap sample, i.e. + random sample with replacement of length `size` from 1-d array. Args: + rng (np.random.Generator): Random number generator data (np.array): 1-d array size (int, optional): Size of bootstrap sample. 
If set to `None`, we set size to the length of the input array diff --git a/causalAssembly/drf_fitting.py b/causalAssembly/drf_fitting.py index f22acc8..c57e2e7 100644 --- a/causalAssembly/drf_fitting.py +++ b/causalAssembly/drf_fitting.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -12,6 +13,7 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ + from __future__ import annotations import numpy as np @@ -32,16 +34,18 @@ class DRF: - """Wrapper around the corresponding R package: + """Wrapper around the corresponding R package. + Distributional Random Forests (Cevid et al., 2020). - Closely adopted from their python wrapper.""" + Closely adopted from their Python wrapper. + """ def __init__(self, **fit_params): + """Initialize DRF object.""" self.fit_params = fit_params - def fit(self, X: pd.DataFrame, Y: pd.DataFrame): - """Fit DRF in order to estimate conditional - distribution P(Y|X=x). + def fit(self, X: pd.DataFrame, Y: pd.DataFrame | pd.Series): + """Fit DRF in order to estimate conditional distribution P(Y|X=x). Args: X (pd.DataFrame): Conditioning set. @@ -88,9 +92,10 @@ def produce_sample( def fit_drf(graph: ProductionLineGraph | ProcessCell | DAG, data: pd.DataFrame): - """Fit distributional random forests to the - factorization implied by the current graph + """Fit distributional random forests to the factorization implied by the current graph. + Args: + graph (ProductionLineGraph | ProcessCell | DAG): Graph to fit the DRF to. data (pd.DataFrame): Columns of dataframe need to match name and order of the graph Raises: diff --git a/causalAssembly/metrics.py b/causalAssembly/metrics.py index 0aeabb8..940d93f 100644 --- a/causalAssembly/metrics.py +++ b/causalAssembly/metrics.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -12,6 +13,7 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ + import copy import networkx as nx @@ -21,6 +23,7 @@ class DAGmetrics: """Class to calculate performance metrics for DAGs. + Make sure that the ground truth and the estimated DAG have the same order of rows/columns. If these objects are nx.DiGraphs, make sure that graph.nodes() have the same order or pass a new nodelist to the class when initiating. The @@ -33,8 +36,19 @@ def __init__( self, truth: nx.DiGraph | pd.DataFrame | np.ndarray, est: nx.DiGraph | pd.DataFrame | np.ndarray, - nodelist: list = None, + nodelist: list[str] | None = None, ): + """Inits the DAGmetrics class. + + Args: + truth (nx.DiGraph | pd.DataFrame | np.ndarray): Ground truth graph. + est (nx.DiGraph | pd.DataFrame | np.ndarray): Estimated graph. + nodelist (list, optional): Node order used when converting graphs to arrays. Defaults to None. + + Raises: + TypeError: If the ground truth graph is not one of the permitted classes. + TypeError: If the estimated graph is not one of the permitted classes. + """ if not isinstance(truth, nx.DiGraph | pd.DataFrame | np.ndarray): raise TypeError("Ground truth graph has to be one of the permitted classes.") @@ -47,7 +61,7 @@ def __init__( self.metrics = None def _calculate_scores(self): - """Calculate Precision, Recall and F1 and g score + """Calculate precision, recall, F1 and g-score.
Return: precision: float @@ -59,8 +73,9 @@ def _calculate_scores(self): gscore: float max(0, (TP-FP))/(TP+FN) """ + TWO = 2 assert self.est.shape == self.truth.shape and self.est.shape[0] == self.est.shape[1] - TP = np.where((self.est + self.truth) == 2, 1, 0).sum(axis=1).sum() + TP = np.where((self.est + self.truth) == TWO, 1, 0).sum(axis=1).sum() TP_FP = self.est.sum(axis=1).sum() FP = TP_FP - TP TP_FN = self.truth.sum(axis=1).sum() @@ -96,12 +111,13 @@ def collect_metrics(self) -> dict[str, float | int]: metrics = self._calculate_scores() metrics["shd"] = self._shd() self.metrics = metrics + return metrics @classmethod def _convert_to_numpy( cls, graph: nx.DiGraph | pd.DataFrame | np.ndarray, - nodelist: list = None, + nodelist: list[str] | None = None, ): if isinstance(graph, np.ndarray): return copy.deepcopy(graph) diff --git a/causalAssembly/models_dag.py b/causalAssembly/models_dag.py index 57bf614..d4d2c3d 100644 --- a/causalAssembly/models_dag.py +++ b/causalAssembly/models_dag.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -12,6 +13,7 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ + from __future__ import annotations import itertools @@ -45,6 +47,8 @@ @dataclass class NodeAttributes: + """Node Attributes.""" + ALLOW_IN_EDGES = "allow_in_edges" HIDDEN = "is_hidden" @@ -64,8 +68,12 @@ def _sample_from_drf( size=size, seed=prod_object.random_state )[0] else: + if prod_object.random_state is not None: + rng = prod_object.random_state + else: + rng = np.random.default_rng() sample_dict[node] = _bootstrap_sample( - rng=prod_object.random_state, + rng=rng, data=prod_object.drf[node].dataset[0], size=size, ) @@ -154,8 +162,8 @@ def _interventional_sample_from_drf( class ProcessCell: - """ - Representation of a single Production Line Cell + """Representation of a single Production Line Cell. + (to model a station / a process in a production line environment). @@ -168,6 +176,11 @@ class ProcessCell: """ def __init__(self, name: str): + """Inits Process cell class. + + Args: + name (str): _description_ + """ self.name = name self.graph: nx.DiGraph = nx.DiGraph() @@ -201,7 +214,7 @@ def edges(self) -> list[tuple]: @property def num_nodes(self) -> int: - """Number of nodes in the graph + """Number of nodes in the graph. Returns: int @@ -210,7 +223,7 @@ def num_nodes(self) -> int: @property def num_edges(self) -> int: - """Number of edges in the graph + """Number of edges in the graph. Returns: int @@ -219,7 +232,7 @@ def num_edges(self) -> int: @property def sparsity(self) -> float: - """Sparsity of the graph + """Sparsity of the graph. Returns: float: in [0,1] @@ -229,8 +242,7 @@ def sparsity(self) -> float: @property def ground_truth(self) -> pd.DataFrame: - """Returns the current ground truth as - pandas adjacency. + """Returns the current ground truth as pandas adjacency. Returns: pd.DataFrame: Adjacenccy matrix. @@ -240,6 +252,7 @@ def ground_truth(self) -> pd.DataFrame: @property def causal_order(self) -> list[str]: """Returns the causal order of the current graph. + Note that this order is in general not unique. 
Returns: @@ -258,7 +271,7 @@ def parents(self, of_node: str) -> list[str]: """ return list(self.graph.predecessors(of_node)) - def save_drf(self, filename: str, location: str = None): + def save_drf(self, filename: str, location: str | Path | None = None): """Writes a drf dict to file. Please provide the .pkl suffix! Args: @@ -266,7 +279,6 @@ def save_drf(self, filename: str, location: str = None): location (str, optional): path to file in case it's not located in the current working directory. Defaults to None. """ - if not location: location = Path().resolve() @@ -281,7 +293,7 @@ def add_module( allow_in_edges: bool = True, mark_hidden: bool | list = False, ) -> str: - """Adds module to cell graph. Module has to be as nx.DiGraph object + """Adds module to cell graph. Module has to be as nx.DiGraph object. Args: graph (nx.DiGraph): Graph to add to cell. @@ -296,15 +308,13 @@ def add_module( Returns: str: prefix of Module created """ - next_module_prefix = self.next_module_prefix() node_renaming_dict = { old_node_name: f"{self.name}_{next_module_prefix}_{old_node_name}" for old_node_name in graph.nodes() } - - self.modules[self.next_module_prefix()] = graph.copy() + self.modules[self.next_module_prefix()] = graph.copy() # type: ignore graph = nx.relabel_nodes(graph, node_renaming_dict) if allow_in_edges: # for later: mark nodes to not have incoming edges @@ -323,14 +333,14 @@ def add_module( nx.set_node_attributes( graph, values=overwrite_dict ) # only overwrite the ones specified - # TODO relabel attributes, i.e. name of the parents has changed now? - # .update_attributes or so or keep and remove prefixes in bayesian network creation? self.graph = nx.compose(self.graph, graph) return next_module_prefix def input_cellgraph_directly(self, graph: nx.DiGraph, allow_in_edges: bool = False): - """Allow to input graphs on a cell-level. This should only be done if the graph + """Allow to input graphs on a cell-level. + + This should only be done if the graph is already available for the entire cell, otherwise `add_module()` is preferred. Args: @@ -362,23 +372,10 @@ def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame: """ return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed) - def interventional_sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame: - """Draw from the trained DRF. - - Args: - size (int, optional): Number of samples to be drawn. Defaults to 10. - smoothed (bool, optional): If set to true, marginal distributions will - be sampled from smoothed bootstraps. Defaults to True. - - Returns: - pd.DataFrame: Data frame that follows the distribution implied by the ground truth. - """ - return _interventional_sample_from_drf(prod_object=self, size=size, smoothed=smoothed) - def _generate_random_dag(self, n_nodes: int = 5, p: float = 0.1) -> nx.DiGraph: - """ - Creates a random DAG by - taking an arbitrary ordering of the specified number of nodes, + """Creates a random DAG. + + By taking an arbitrary ordering of the specified number of nodes, and then considers edges from node i to j only if i < j. That constraint leads to DAGness by construction. 
@@ -389,26 +386,37 @@ def _generate_random_dag(self, n_nodes: int = 5, p: float = 0.1) -> nx.DiGraph: Returns: nx.DiGraph: """ + rng = self.random_state + if rng is None: + rng = np.random.default_rng() dag = nx.DiGraph() dag.add_nodes_from(range(0, n_nodes)) causal_order = list(dag.nodes) - self.random_state.shuffle(causal_order) + rng.shuffle(causal_order) all_forward_edges = itertools.combinations(causal_order, 2) edges = np.array(list(all_forward_edges)) - random_choice = self.random_state.choice([False, True], p=[1 - p, p], size=edges.shape[0]) + random_choice = rng.choice([False, True], p=[1 - p, p], size=edges.shape[0]) dag.add_edges_from(edges[random_choice]) return dag def add_random_module(self, n_nodes: int = 7, p: float = 0.10): + """Add random module to the cell. + + Args: + n_nodes (int, optional): _description_. Defaults to 7. + p (float, optional): _description_. Defaults to 0.10. + """ randomdag = self._generate_random_dag(n_nodes=n_nodes, p=p) self.add_module(graph=randomdag, allow_in_edges=True, mark_hidden=False) def connect_by_module(self, m1: str, m2: str, edges: list[tuple]): - """Connect two modules (by name, e.g. M2, M4) of the cell by a list + """Connect two modules. + + (by name, e.g. M2, M4) of the cell by a list of edges with the original node names. Args: @@ -431,8 +439,9 @@ def connect_by_module(self, m1: str, m2: str, edges: list[tuple]): self.graph.add_edges_from(new_edges) def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph: - """ - Add random edges to graph according to proportion, + """Add random edges to graph. + + according to proportion with restriction specified in node attributes. Args: @@ -445,7 +454,9 @@ def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph: Returns: nx.DiGraph: DAG with new edges added. """ - + rng = self.random_state + if rng is None: + rng = np.random.default_rng() arrow_head_candidates = get_arrow_head_candidates_from_graph( graph=self.graph, node_attributes_to_filter=NodeAttributes.ALLOW_IN_EDGES ) @@ -460,9 +471,7 @@ def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph: ### choose edges uniformly according to sparsity parameter chosen_edges = [ potential_edges[i] - for i in self.random_state.choice( - a=len(potential_edges), size=num_choices, replace=False - ) + for i in rng.choice(a=len(potential_edges), size=num_choices, replace=False) ] self.graph.add_edges_from(chosen_edges) @@ -474,24 +483,33 @@ def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph: return self.graph def __repr__(self): + """Repr method. + + Returns: + _type_: _description_ + """ return f"ProcessCell(name={self.name})" def __str__(self): + """Str method. + + Returns: + _type_: _description_ + """ cell_description = { "Cell Name: ": self.name, "Description:": self.description if self.description else "n.a.", "Modules:": self.no_of_modules, "Nodes: ": self.num_nodes, } - s = str() + s = "" for info, info_text in cell_description.items(): s += f"{info:<14}{info_text:>5}\n" return s def __verify_edges_are_allowed(self, m1: str, m2: str, edges: list[tuple]): - """Check whether all starting point nodes - (first value in edge tuple) are allowed. + """Check whether all starting point nodes (first value in edge tuple) are allowed. 
Args: m1 (str): Module1 @@ -504,8 +522,8 @@ def __verify_edges_are_allowed(self, m1: str, m2: str, edges: list[tuple]): """ source_nodes = set([e[0] for e in edges]) target_nodes = set([e[1] for e in edges]) - m1_nodes = set(self.modules.get(m1).nodes()) - m2_nodes = set(self.modules.get(m2).nodes()) + m1_nodes = set(self.modules.get(m1).nodes()) # type: ignore + m2_nodes = set(self.modules.get(m2).nodes()) # type: ignore if not source_nodes.issubset(m1_nodes): raise ValueError(f"source nodes: {source_nodes} not include in {m1}s nodes: {m1_nodes}") @@ -514,6 +532,7 @@ def __verify_edges_are_allowed(self, m1: str, m2: str, edges: list[tuple]): def next_module_prefix(self) -> str: """Return the next module prefix, e.g. + if there are already 3 modules connected to the cell, will return module_prefix4 @@ -524,6 +543,11 @@ def next_module_prefix(self) -> str: @property def module_prefix(self) -> str: + """Module prefix. + + Returns: + str: _description_ + """ return self.__module_prefix @module_prefix.setter @@ -535,12 +559,19 @@ def module_prefix(self, module_prefix: str): @property def no_of_modules(self) -> int: - return len(self.modules) + """Number of modules in the cell. - def get_nodes_by_attribute(self, attr_name: str, submodule: str = None) -> list: - pass + Returns: + int: _description_ + """ + return len(self.modules) def get_available_attributes(self): + """Get available attributes of the nodes in the graph. + + Returns: + _type_: _description_ + """ available_attributes = set() for node_tuple in self.graph.nodes(data=True): for attribute_name in node_tuple[1].keys(): @@ -549,13 +580,20 @@ def get_available_attributes(self): return list(available_attributes) def to_cpdag(self) -> PDAG: + """To CPDAG conversion. + + Returns: + PDAG: _description_ + """ return dag2cpdag(dag=self.graph) def show( self, meta_desc: str = "", ): - """Plots the cell graph by giving extra weight to nodes + """Plots the cell graph. + + by giving extra weight to nodes with high in- and out-degree. Args: @@ -583,11 +621,12 @@ def show( vmin=-0.2, vmax=1, node_color=[ - (d + 10) / (max_in_degree + 10) for _, d in self.graph.in_degree(self.nodes) + (d + 10) / (max_in_degree + 10) + for _, d in self.graph.in_degree(self.nodes) # type: ignore ], node_size=[ 500 * (d + 1) / (max_out_degree + 1) for _, d in self.graph.out_degree(self.nodes) - ], + ], # type: ignore ) nx.draw_networkx_edges( @@ -620,15 +659,22 @@ def _plot_cellgraph( with_box=True, meta_desc="", ): - """Plots the cell graph by giving extra weight to nodes + """Plots the cell graph. + + by giving extra weight to nodes with high in- and out-degree. Args: - with_edges (bool, optional): Defaults to True. - with_box (bool, optional): Defaults to True. - meta_desc (str, optional): Defaults to "". - center (_type_, optional): Defaults to np.array([0, 0]). - fig_size (tuple, optional): Defaults to (2, 8). + ax (_type_): _description_ + node_color (_type_): _description_ + node_size (_type_): _description_ + center (_type_, optional): _description_. Defaults to np.array([0, 0]). + with_edges (bool, optional): _description_. Defaults to True. + with_box (bool, optional): _description_. Defaults to True. + meta_desc (str, optional): _description_. Defaults to "". 
+ + Returns: + _type_: _description_ """ cmap = plt.get_cmap("cividis") @@ -674,7 +720,7 @@ def _plot_cellgraph( PatchCollection( [ FancyBboxPatch( - center - [2, 1], + center - [2, 1], # type: ignore 4, 2.6, boxstyle=BoxStyle("Round", pad=0.02), @@ -695,7 +741,8 @@ def choose_edges_from_cells_randomly( probability: float, rng: np.random.Generator, ) -> list[tuple[str, str]]: - """ + """Choose edges between two cells randomly. + From two given cells (graphs), we take the cartesian product (end up with from_cell.number_of_nodes x to_cell.number_of_nodes possible edges (node tuples). @@ -708,12 +755,13 @@ from_cell: ProcessCell from where we want the edges to_cell: ProcessCell to where we want the edges probability: between 0 and 1 + rng (np.random.Generator): Random number generator to use. Returns: list[tuple[str, str]]: Chosen edges. """ - - assert 0 <= probability <= 1.0 + ONE = 1.0 + assert 0 <= probability <= ONE arrow_tail_candidates = list(from_cell.graph.nodes) arrow_head_candidates = get_arrow_head_candidates_from_graph(graph=to_cell.graph) @@ -736,6 +784,7 @@ def get_arrow_head_candidates_from_graph( graph: nx.DiGraph, node_attributes_to_filter: str = NodeAttributes.ALLOW_IN_EDGES ) -> list[str]: """Returns all arrow head (nodes where an arrow points to) nodes as list of candidates. + To later build a list of tuples of potential edges. Args: @@ -791,6 +840,27 @@ class ProductionLineGraph: """ def __init__(self): + """Inits ProductionLineGraph with empty cells and a default random state.""" self._random_state = np.random.default_rng(seed=2023) self.cells: dict[str, ProcessCell] = dict() self.cell_prefix = "C" @@ -803,6 +873,11 @@ def __init__(self): @property def random_state(self): + """Random state. + + Returns: + np.random.Generator: The random number generator in use. + """ return self._random_state @random_state.setter @@ -816,8 +891,7 @@ def __init_mutilated_dag(self): @property def graph(self) -> nx.DiGraph: - """ - Returns a nx.DiGraph object of the actual graph. + """Returns a nx.DiGraph object of the actual graph. The graph is only built HERE, i.e. all ProcessCells exist standalone in self.cells, with no connections between their nodes yet. @@ -867,7 +941,7 @@ def edges(self) -> list[tuple]: @property def num_nodes(self) -> int: - """Number of nodes in the graph + """Number of nodes in the graph. Returns: int @@ -876,7 +950,7 @@ def num_nodes(self) -> int: @property def num_edges(self) -> int: - """Number of edges in the graph + """Number of edges in the graph. Returns: int @@ -885,7 +959,7 @@ def num_edges(self) -> int: @property def sparsity(self) -> float: - """Sparsity of the graph + """Sparsity of the graph. Returns: float: in [0,1] @@ -895,8 +969,7 @@ def sparsity(self) -> float: @property def ground_truth(self) -> pd.DataFrame: - """Returns the current ground truth as - pandas adjacency. + """Returns the current ground truth as pandas adjacency. Returns: pd.DataFrame: Adjacency matrix.
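To make the `ProductionLineGraph` surface touched in these hunks concrete, here is a hedged usage sketch; `via_cell_number`, `add_random_module`, `ground_truth`, `num_nodes`, and cell access by attribute name all appear in this diff, while the `connect_cells` keyword `forward_probs` is an assumption about the existing signature, which the hunks only show in part:

```python
from causalAssembly.models_dag import ProductionLineGraph

line = ProductionLineGraph.via_cell_number(n_cells=2)  # creates empty cells C0, C1
line.C0.add_random_module(n_nodes=5, p=0.2)            # cells resolved via __getattr__
line.C1.add_random_module(n_nodes=4, p=0.2)
line.connect_cells(forward_probs=[0.1])                # keyword assumed, not shown in the hunk

amat = line.ground_truth  # pandas adjacency: amat.loc[i, j] == 1 means i -> j
assert amat.shape == (line.num_nodes, line.num_nodes)
```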
@@ -913,8 +986,7 @@ def _get_union_graph(self) -> nx.DiGraph: @property def within_adjacency(self) -> pd.DataFrame: - """Returns adjacency matrix ignoring all - between-cell edges. + """Returns adjacency matrix ignoring all between-cell edges. Returns: pd.DataFrame: adjacency matrix @@ -924,8 +996,7 @@ def within_adjacency(self) -> pd.DataFrame: @property def between_adjacency(self) -> pd.DataFrame: - """Returns adjacency matrix ignoring all - within-cell edges. + """Returns adjacency matrix ignoring all within-cell edges. Returns: pd.DataFrame: adjacency matrix @@ -936,6 +1007,7 @@ def between_adjacency(self) -> pd.DataFrame: @property def causal_order(self) -> list[str]: """Returns the causal order of the current graph. + Note that this order is in general not unique. Returns: @@ -955,6 +1027,11 @@ def parents(self, of_node: str) -> list[str]: return list(self.graph.predecessors(of_node)) def to_cpdag(self) -> PDAG: + """Convert to CPDAG. + + Returns: + PDAG: The corresponding CPDAG. + """ return dag2cpdag(dag=self.graph) def get_nodes_of_station(self, station_name: str) -> list: @@ -990,7 +1067,7 @@ def __add_cell(self, cell: ProcessCell) -> ProcessCell: raise ValueError(f"A cell with name: {cell.name} is already in the Production Line.") - def new_cell(self, name: str = None, is_eol: bool = False) -> ProcessCell: + def new_cell(self, name: str | None = None, is_eol: bool = False) -> ProcessCell: """Add a new cell to the production line. If no name is given, cell name is given by counting available cells + 1 @@ -1009,7 +1086,7 @@ def new_cell(self, name: str = None, is_eol: bool = False) -> ProcessCell: actual_no_of_cells = len(self.cells.values()) c = ProcessCell(name=f"{self.cell_prefix}{actual_no_of_cells}") - c.random_state = self.random_state + c.random_state = self.random_state # type: ignore c.is_eol = is_eol self.__add_cell(cell=c) @@ -1048,7 +1125,7 @@ def connect_cells( rng=self.random_state, ) - prob_it += 1 # FIXME: a bit ugly and hard to read + prob_it += 1 self.cell_connector_edges.extend(chosen_edges) if eol_cell := self.eol_cell: @@ -1066,8 +1143,7 @@ def connect_cells( self.cell_connector_edges.extend(chosen_eol_edges) def copy(self) -> ProductionLineGraph: - """Makes a full copy of the current - ProductionLineGraph object + """Makes a full copy of the current ProductionLineGraph object. Returns: ProductionLineGraph: copied object. @@ -1088,14 +1164,18 @@ def copy(self) -> ProductionLineGraph: return copy_graph def connect_across_cells_manually(self, edges: list[tuple]): - """Add edges manually across cells. You need to give the full name + """Add edges manually across cells. + + You need to give the full node names Args: edges (list[tuple]): list of edges to add """ self.cell_connector_edges.extend(edges) def intervene_on(self, nodes_values: dict[str, RandomSymbol | float]): - """Specify hard or soft intervention. If you want to intervene + """Specify hard or soft intervention. + + If you want to intervene upon more than one node provide a list of nodes to intervene on and a list of corresponding values to set these nodes to. (see example).
The mutilated dag will automatically be @@ -1126,15 +1206,15 @@ mutilated_dag.remove_edges_from(edges_to_remove) drf_replace[node] = value - self.mutilated_dags[ - f"do({list(nodes_values.keys())})" - ] = mutilated_dag # specifiying the same set twice will override + self.mutilated_dags[f"do({list(nodes_values.keys())})"] = ( + mutilated_dag # specifying the same set twice will override + ) self.interventional_drf[f"do({list(nodes_values.keys())})"] = drf_replace @property def interventions(self) -> list: - """Returns all interventions performed on the original graph + """Returns all interventions performed on the original graph. Returns: list: list of intervened upon nodes in do(x) notation. @@ -1168,13 +1248,13 @@ def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame: @classmethod def get_ground_truth(cls) -> ProductionLineGraph: - """Loads in the ground_truth as described in the paper: + """Loads in the ground_truth as described in the paper. + causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery Returns: ProductionLineGraph: ground_truth for cells and line. """ - gt_response = requests.get(DATA_GROUND_TRUTH, timeout=5) ground_truth = json.loads(gt_response.text) @@ -1201,7 +1281,8 @@ def get_ground_truth(cls) -> ProductionLineGraph: @classmethod def get_data(cls) -> pd.DataFrame: - """Load in semi-synthetic data as described in the paper: + """Load in semi-synthetic data as described in the paper. + causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery Returns: @@ -1211,7 +1292,9 @@ def get_data(cls) -> pd.DataFrame: @classmethod def from_nx(cls, g: nx.DiGraph, cell_mapper: dict[str, list]): - """Convert nx.DiGraph to ProductionLineGraph. Requires a dict mapping + """Convert nx.DiGraph to ProductionLineGraph. + + Requires a dict mapping where keys are cell names and values correspond to nodes within these cells. Args: @@ -1240,7 +1323,7 @@ def from_nx(cls, g: nx.DiGraph, cell_mapper: dict[str, list]): return pline @classmethod - def load_drf(cls, filename: str, location: str = None): + def load_drf(cls, filename: str, location: str | Path | None = None): """Loads a drf dict from a .pkl file into the workspace. Args: @@ -1262,7 +1345,19 @@ def load_drf(cls, filename: str, location: str = None): return pickle_drf @classmethod - def load_pline_from_pickle(cls, filename: str, location: str = None): + def load_pline_from_pickle(cls, filename: str, location: str | Path | None = None): + """Load production line graph from a pickle file. + + Args: + filename (str): Name of the pickle file. + location (str | Path | None, optional): Path to the file in case it is not located in the current working directory. Defaults to None. + + Raises: + TypeError: If the file does not contain a ProductionLineGraph. + + Returns: + ProductionLineGraph: The loaded production line. + """ if not location: location = Path().resolve() @@ -1276,7 +1371,7 @@ def load_pline_from_pickle(cls, filename: str, location: str = None): return pickle_line - def save_drf(self, filename: str, location: str = None): + def save_drf(self, filename: str, location: str | Path | None = None): """Writes a drf dict to file. Please provide the .pkl suffix! Args: @@ -1284,7 +1379,6 @@ def save_drf(self, filename: str, location: str = None): location (str, optional): path to file in case it's not located in the current working directory. Defaults to None.
""" - if not location: location = Path().resolve() @@ -1328,7 +1422,7 @@ def sample_from_interventional_drf( ) def hidden_nodes(self) -> list: - """Returns list of nodes marked as hidden + """Returns list of nodes marked as hidden. Returns: list: of hidden nodes @@ -1340,16 +1434,19 @@ def hidden_nodes(self) -> list: ] def visible_nodes(self): + """All visible nodes in the graph. + + Returns: + _type_: _description_ + """ return [node for node in self.nodes if node not in self.hidden_nodes()] @property def eol_cell(self) -> ProcessCell | None: - """ + """Returns ProcessCell. - Returns ProcessCell: the EOL cell + the EOL cell (if any single cell has attr .is_eol = True), otherwise returns None - ------- - """ for cell in self.cells.values(): if cell.is_eol: @@ -1358,6 +1455,7 @@ def eol_cell(self) -> ProcessCell | None: @property def ground_truth_visible(self) -> pd.DataFrame: """Generates a ground truth graph in form of a pandas adjacency matrix. + Row and column names correspond to visible. The following integers can occur: @@ -1368,7 +1466,6 @@ def ground_truth_visible(self) -> pd.DataFrame: Returns: pd.DataFrame: amat with values in {0,1,2}. """ - if len(self.hidden_nodes()) == 0: return self.ground_truth else: @@ -1385,8 +1482,7 @@ def ground_truth_visible(self) -> pd.DataFrame: # reverse = lambda tuples: tuples[::-1] def reverse(tuples): - """ - Simple function to reverse tuple order + """Simple function to reverse tuple order. Args: tuples (tuple): tuple to reverse order @@ -1404,7 +1500,7 @@ def reverse(tuples): return amat_visible def show(self, meta_description: list | None = None, fig_size: tuple = (15, 8)): - """Plot full assembly line + """Plot full assembly line. Args: meta_description (list | None, optional): Specify additional cell info. @@ -1458,27 +1554,51 @@ def show(self, meta_description: list | None = None, fig_size: tuple = (15, 8)): ) def __str__(self): + """String method for ProductionLineGraph. + + Returns: + _type_: _description_ + """ s = "ProductionLine\n\n" for cell in self.cells: s += f"{cell}\n" return s def __getattr__(self, attrname): + """Get a cell by its name. + + Args: + attrname (_type_): _description_ + + Raises: + AttributeError: _description_ + + Returns: + _type_: _description_ + """ if attrname not in self.cells.keys(): raise AttributeError(f"{attrname} is not a valid attribute (cell name?)") return self.cells[attrname] - # https://docs.python.org/3/library/pickle.html#pickle-protocol - # TODO why is .cells enough, are the other member vars directly pickable? def __getstate__(self): + """Get current state of the ProductionLineGraph. + + Returns: + _type_: _description_ + """ return (self.__dict__, self.cells) def __setstate__(self, state): + """Set state of the ProductionLineGraph. + + Args: + state (_type_): _description_ + """ self.__dict__, self.cells = state @classmethod def via_cell_number(cls, n_cells: int, cell_prefix: str = "C"): - """Inits a ProductionLineGraph with predefined number of cells, e.g. n_cells = 3 + """Inits a ProductionLineGraph with predefined number of cells, e.g. n_cells = 3. Will create empty C0, C1 and C2 as cells if no other cell_prefix is given. @@ -1496,8 +1616,7 @@ def via_cell_number(cls, n_cells: int, cell_prefix: str = "C"): return pl def _pairs_with_hidden_mediators(self): - """ - Return pairs of nodes with hidden mediators present. + """Return pairs of nodes with hidden mediators present. 
Args: graph (nx.DiGraph): DAG @@ -1506,6 +1625,7 @@ def _pairs_with_hidden_mediators(self): Returns: list: list of tuples with pairs of nodes with hidden mediator """ + TWO = 2 any_paths = [] visible = self.visible_nodes() hidden_all = self.hidden_nodes() @@ -1517,14 +1637,15 @@ def _pairs_with_hidden_mediators(self): any_paths.append(path) pairs_with_hidden_mediators = [ - (ls[0], ls[-1]) for ls in any_paths if np.all(np.isin(ls[1:-1], hidden)) and len(ls) > 2 + (ls[0], ls[-1]) + for ls in any_paths + if np.all(np.isin(ls[1:-1], hidden)) and len(ls) > TWO ] return pairs_with_hidden_mediators def _pairs_with_hidden_confounders(self) -> dict: - """ - Returns node-pairs that have a common hidden confounder + """Returns node-pairs that have a common hidden confounder. Returns: dict: Dictionary with keys equal to tuples of node-pairs diff --git a/causalAssembly/models_fcm.py b/causalAssembly/models_fcm.py index 2fc890f..00002fc 100644 --- a/causalAssembly/models_fcm.py +++ b/causalAssembly/models_fcm.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -12,6 +13,7 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ + from __future__ import annotations import logging @@ -52,6 +54,12 @@ class FCM: """ def __init__(self, name: str | None = None, seed: int = 2023): + """Inits the FCM class. + + Args: + name (str | None, optional): Name. Defaults to None. + seed (int, optional): Seed. Defaults to 2023. + """ self.name = name self._random_state = np.random.default_rng(seed=seed) self.__init_dag() @@ -76,6 +84,7 @@ def source_nodes(self) -> list: @property def causal_order(self) -> list[Symbol]: """Returns the causal order of the current graph. + Note that this order is in general not unique. To ensure uniqueness, we additionally sort lexicograpically. @@ -104,7 +113,7 @@ def edges(self) -> list[tuple]: @property def num_nodes(self) -> int: - """Number of nodes in the graph + """Number of nodes in the graph. Returns: int @@ -113,7 +122,7 @@ def num_nodes(self) -> int: @property def num_edges(self) -> int: - """Number of edges in the graph + """Number of edges in the graph. Returns: int @@ -122,7 +131,7 @@ def num_edges(self) -> int: @property def sparsity(self) -> float: - """Sparsity of the graph + """Sparsity of the graph. Returns: float: in [0,1] @@ -132,8 +141,7 @@ def sparsity(self) -> float: @property def ground_truth(self) -> pd.DataFrame: - """Returns the current ground truth as - pandas adjacency. + """Returns the current ground truth as pandas adjacency. Returns: pd.DataFrame: Adjacenccy matrix. @@ -141,8 +149,8 @@ def ground_truth(self) -> pd.DataFrame: return nx.to_pandas_adjacency(self.graph, weight=None) @property - def interventions(self) -> list: - """Returns all interventions performed on the original graph + def interventions(self) -> list[str]: + """Returns all interventions performed on the original graph. Returns: list: list of intervened upon nodes in do(x) notation. @@ -199,6 +207,7 @@ def parents_of(self, node: Symbol, which_graph: nx.DiGraph) -> list[Symbol]: def causal_order_of(self, which_graph: nx.DiGraph) -> list[Symbol]: """Returns the causal order of the chosen graph. + Note that this order is in general not unique. To ensure uniqueness, we additionally sort lexicograpically. 
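For orientation, a minimal sketch of how the FCM API above fits together; the three-node system and the noise names are illustrative, loosely mirroring the test fixtures later in this diff:

from sympy import Eq, symbols
from sympy.stats import Normal, Uniform

from causalAssembly.models_fcm import FCM

x, y, z = symbols("x,y,z")
fcm = FCM(name="toy_example", seed=2023)
fcm.input_fcm(
    [
        Eq(x, Uniform("u_x", left=-1, right=1)),  # source node, pure noise
        Eq(y, 2 * x**2 + Normal("u_y", 0, 1)),    # child of x with additive Gaussian noise
        Eq(z, x + y + Normal("u_z", 0, 1)),       # collider with parents x and y
    ]
)
print(fcm.causal_order)    # lexicographically sorted topological order, e.g. [x, y, z]
df = fcm.sample(size=100)  # DataFrame with columns x, y, z

Keeping a noise term in every equation avoids the fully deterministic systems the sampling docstrings below warn about.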
@@ -208,7 +217,9 @@ def causal_order_of(self, which_graph: nx.DiGraph) -> list[Symbol]: return list(nx.lexicographical_topological_sort(which_graph, key=lambda x: str(x))) def source_nodes_of(self, which_graph: nx.DiGraph) -> list: - """Returns the source nodes of a chosen graph. This is mainly for + """Returns the source nodes of a chosen graph. + + This is mainly for choosing different mutilated DAGs. Args: @@ -224,8 +235,8 @@ def source_nodes_of(self, which_graph: nx.DiGraph) -> list: ] def input_fcm(self, fcm: list[Eq]): - """ - Automatically builds up DAG according to the FCM fed in. + """Automatically builds up DAG according to the FCM fed in. + Args: fcm (list): list of sympy equations generated as: ```[python] @@ -300,9 +311,12 @@ def sample( snr: None | float = 1 / 2, source_df: None | pd.DataFrame = None, ) -> pd.DataFrame: - """Draw samples from the joint distribution that factorizes - according to the DAG implied by the FCM fed in. To avoid - unexpected/unintended behavior, avoid defining fully + r"""Sample from joint. + + Draw samples from the joint distribution that factorizes + according to the DAG implied by the FCM fed in. + + To avoid unexpected/unintended behavior, avoid defining fully deterministic equation systems. If parameters in noise terms are additive and left unevaluated, they're set according to a chosen Signal-To-Noise (SNR) ratio. @@ -349,8 +363,9 @@ def interventional_sample( snr: None | float = 1 / 2, source_df: None | pd.DataFrame = None, ) -> pd.DataFrame: - """Draw samples from the interventional distribution that factorizes - according to the mutilated DAG after performing one or multiple + r"""Draw samples from the interventional distribution. + + that factorizes according to the mutilated DAG after performing one or multiple interventions. Otherwise the method behaves similar to sampling from the non-interventional joint distribution. By default samples are drawn from the first intervention you performed. If you intervened upon more than one node, @@ -409,7 +424,9 @@ def _sample( snr: None | float = 1 / 2, source_df: None | pd.DataFrame = None, ) -> pd.DataFrame: - """Draw samples from the joint distribution that factorizes + r"""Draw samples from the joint distribution. + + that factorizes according to the DAG implied by the FCM fed in. To avoid unexpected/unintended behavior, avoid defining fully deterministic equation systems. @@ -422,6 +439,7 @@ def _sample( Args: size (int): Number of samples to draw. + which_graph (nx.DiGraph): Which graph to sample from. additive_gaussian_noise (bool, optional): _description_. Defaults to False. snr (None | float, optional): Signal-to-noise ratio \\( SNR = \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\). @@ -439,7 +457,6 @@ def _sample( pd.DataFrame: Data frame with rows of lenght size and columns equal to the number of nodes in the graph. """ - if source_df is not None and not self.__source_df_condition(source_df): raise AssertionError("Names in source_df don't match nodenames in graph.") @@ -512,11 +529,11 @@ def _sample( + str(fcm_expr) + " according to the given SNR." 
            )
-            noise_var = df[str(order)].var() / snr
+            noise_var = df[str(order)].var() / snr  # type: ignore
            df[str(order)] = df[str(order)] + sympy_sample(
                fcm_expr.atoms(RandomSymbol)
                .pop()
-                .subs(self.__unfree_symbol(fcm_expr), np.sqrt(noise_var)),
+                .subs(self.__unfree_symbol(fcm_expr), np.sqrt(noise_var)),  # type: ignore
                size=size,
                seed=self._random_state,
            )
@@ -532,9 +549,11 @@
            )
        else:
            noise = symbols("noise")
-            noise_var = df[str(order)].var() / snr
+            noise_var = df[str(order)].var() / snr  # type: ignore
            df[str(noise)] = self._random_state.normal(
-                loc=0, scale=np.sqrt(noise_var), size=size
+                loc=0,
+                scale=np.sqrt(noise_var),  # type: ignore
+                size=size,
            )
            fcm_expr = which_graph.nodes[order]["term"] + noise
            df[str(order)] = self.__eval_expression(df=df, fcm_expr=fcm_expr)
@@ -552,7 +571,7 @@
        }.pop()

    def __eval_expression(self, df: pd.DataFrame, fcm_expr: Expr) -> pd.DataFrame:
-        """Eval given fcm_expression with the values in given dataframe
+        """Eval given fcm_expression with the values in the given dataframe.

        Args:
            df (pd.DataFrame): Data frame.
@@ -561,7 +580,6 @@
        Returns:
            pd.DataFrame: Data frame after eval.
        """
-
        correct_order = list(ordered(fcm_expr.free_symbols))
        # self.__return_ordered_args(fcm_expr)
        cols = [str(col) for col in correct_order]
        evaluator = lambdify(correct_order, fcm_expr, "scipy")

        return evaluator(*[df[col] for col in cols])

    def __distribution_parameters_explicit(self, order: Symbol, which_graph: nx.DiGraph) -> bool:
-        """Returns true if distribution parameters
-        are given explicitly, not symbolically.
+        """Returns true if distribution parameters are given explicitly, not symbolically.

        Args:
            order (node): node in graph
+            which_graph (nx.DiGraph): which graph to choose.

        Returns:
            bool:
@@ -596,7 +614,9 @@
        )

    def intervene_on(self, nodes_values: dict[Symbol, RandomSymbol | float]):
-        """Specify hard or soft intervention. If you want to intervene
+        """Specify hard or soft intervention.
+
+        If you want to intervene
        upon more than one node provide a list of nodes to intervene on
        and a list of corresponding values to set these nodes to. (see example).
        The mutilated dag will automatically be
@@ -625,7 +645,6 @@
        ```
        """
-
        if not set(nodes_values.keys()).issubset(set(self.nodes)):
            raise AssertionError(
                "One or more nodes you want to intervene upon are not in the graph."
@@ -640,11 +659,11 @@
            mutilated_dag.remove_edges_from(edges_to_remove)

            mutilated_dag.nodes[node]["term"] = intervention.rhs

-        self.mutilated_dags[
-            f"do({list(nodes_values.keys())})"
-        ] = mutilated_dag  # specifiying the same set twice will override
+        self.mutilated_dags[f"do({list(nodes_values.keys())})"] = (
+            mutilated_dag  # specifying the same set twice will override
+        )

-    def show(self, header: str | None = None, with_nodenames: bool = True) -> plt:
+    def show(self, header: str | None = None, with_nodenames: bool = True):
        """Plots the current DAG.
Args: @@ -659,10 +678,8 @@ def show(self, header: str | None = None, with_nodenames: bool = True) -> plt: header = "" return self._show(which_graph=self.graph, header=header, with_nodenames=with_nodenames) - def show_mutilated_dag( - self, which_intervention: str | int = 0, with_nodenames: bool = True - ) -> plt: - """Plot mutilated DAG + def show_mutilated_dag(self, which_intervention: str | int = 0, with_nodenames: bool = True): + """Plot mutilated DAG. Args: which_intervention (str | int, optional): Which interventional distribution @@ -683,8 +700,12 @@ def show_mutilated_dag( ) def _show(self, which_graph: nx.DiGraph, header: str, with_nodenames: bool): - """Plots the graph by giving extra weight to nodes - with high in- and out-degree. + """Plots the graph by giving extra weight to nodes with high in- and out-degree. + + Args: + which_graph (nx.DiGraph): _description_ + header (str): _description_ + with_nodenames (bool): _description_ """ cmap = plt.get_cmap("Blues") fig, ax = plt.subplots() @@ -712,10 +733,10 @@ def _show(self, which_graph: nx.DiGraph, header: str, with_nodenames: bool): vmax=1, node_color=[ (d + 10) / (max_in_degree + 10) for _, d in which_graph.in_degree(self.nodes) - ], + ], # type: ignore node_size=[ 500 * (d + 1) / (max_out_degree + 1) for _, d in which_graph.out_degree(self.nodes) - ], + ], # type: ignore ) if with_nodenames: diff --git a/causalAssembly/pdag.py b/causalAssembly/pdag.py index b709b11..c7964e7 100644 --- a/causalAssembly/pdag.py +++ b/causalAssembly/pdag.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -27,17 +28,32 @@ class PDAG: - """ - Class for dealing with partially directed graph i.e. - graphs that contain both directed and undirected edges. + """Class for dealing with partially directed graphs. + + i.e., graphs that contain both directed and undirected edges. """ def __init__( self, - nodes: list | None = None, - dir_edges: list[tuple] | None = None, - undir_edges: list[tuple] | None = None, + nodes: list[str] | list[int] | set[str] | set[int] | None = None, + dir_edges: list[tuple[str, str]] + | list[tuple[int, int]] + | set[tuple[str, str]] + | set[tuple[int, int]] + | None = None, + undir_edges: list[tuple[str, str]] + | list[tuple[int, int]] + | set[tuple[str, str]] + | set[tuple[int, int]] + | None = None, ): + """Inits the PDAG class. + + Args: + nodes (list | None, optional): _description_. Defaults to None. + dir_edges (list[tuple] | None, optional): _description_. Defaults to None. + undir_edges (list[tuple] | None, optional): _description_. Defaults to None. + """ if nodes is None: nodes = [] if dir_edges is None: @@ -80,7 +96,7 @@ def _add_undir_edge(self, i, j): self._undirected_neighbors[i].add(j) self._undirected_neighbors[j].add(i) - def children(self, node: str) -> set: + def children(self, node: str | int) -> set: """Gives all children of node `node`. Args: @@ -94,7 +110,7 @@ def children(self, node: str) -> set: else: return set() - def parents(self, node: str) -> set: + def parents(self, node: str | int) -> set: """Gives all parents of node `node`. Args: @@ -108,7 +124,7 @@ def parents(self, node: str) -> set: else: return set() - def neighbors(self, node: str) -> set: + def neighbors(self, node: str | int) -> set: """Gives all neighbors of node `node`. 
Args: @@ -122,9 +138,8 @@ def neighbors(self, node: str) -> set: else: return set() - def undir_neighbors(self, node: str) -> set: - """Gives all undirected neighbors - of node `node`. + def undir_neighbors(self, node: str | int) -> set: + """Gives all undirected neighbors of node `node`. Args: node (str): node in current PDAG. @@ -138,8 +153,7 @@ def undir_neighbors(self, node: str) -> set: return set() def is_adjacent(self, i: str, j: str) -> bool: - """Return True if the graph contains an directed - or undirected edge between i and j. + """Return True if the graph contains an directed or undirected edge between i and j. Args: i (str): node i. @@ -156,9 +170,7 @@ def is_adjacent(self, i: str, j: str) -> bool: ) def is_clique(self, potential_clique: set) -> bool: - """ - Check every pair of node X potential_clique is adjacent. - """ + """Check every pair of node X potential_clique is adjacent.""" return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2)) @classmethod @@ -172,7 +184,7 @@ def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> PDAG: PDAG """ assert pd_amat.shape[0] == pd_amat.shape[1] - nodes = pd_amat.columns + nodes = list(pd_amat.columns) all_connections = [] start, end = np.where(pd_amat != 0) @@ -188,7 +200,7 @@ def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> PDAG: return PDAG(nodes=nodes, dir_edges=dir_edges, undir_edges=undir_edges) def remove_edge(self, i: str, j: str): - """Removes edge in question + """Removes edge in question. Args: i (str): tail @@ -211,6 +223,7 @@ def remove_edge(self, i: str, j: str): def undir_to_dir_edge(self, tail: str, head: str): """Takes a undirected edge and turns it into a directed one. + tail indicates the starting node of the edge and head the end node, i.e. tail -> head. @@ -236,12 +249,12 @@ def undir_to_dir_edge(self, tail: str, head: str): self._add_dir_edge(i=tail, j=head) def remove_node(self, node): - """Remove a node from the graph""" + """Remove a node from the graph.""" self._nodes.remove(node) - self._dir_edges = {(i, j) for i, j in self._dir_edges if i != node and j != node} + self._dir_edges = {(i, j) for i, j in self._dir_edges if node not in (i, j)} - self._undir_edges = {(i, j) for i, j in self._undir_edges if i != node and j != node} + self._undir_edges = {(i, j) for i, j in self._undir_edges if node not in (i, j)} for child in self._children[node]: self._parents[child].remove(node) @@ -261,8 +274,7 @@ def remove_node(self, node): self._undirected_neighbors.pop(node, "I was never here") def to_dag(self) -> nx.DiGraph: - """ - Algorithm as described in Chickering (2002): + r"""Algorithm as described in Chickering (2002). 1. From PDAG P create DAG G containing all directed edges from P 2. Repeat the following: Select node v in P s.t. @@ -278,7 +290,6 @@ def to_dag(self) -> nx.DiGraph: Returns: nx.DiGraph: DAG that belongs to the MEC implied by the PDAG """ - pdag = self.copy() dag = nx.DiGraph() @@ -308,7 +319,7 @@ def to_dag(self) -> nx.DiGraph: for edge in pdag.undir_edges: if node in edge: incident_node = set(edge) - {node} - dag.add_edge(*incident_node, node) + dag.add_edge(*incident_node, node) # type: ignore pdag.remove_node(node) break @@ -324,7 +335,9 @@ def to_dag(self) -> nx.DiGraph: @property def adjacency_matrix(self) -> pd.DataFrame: - """Returns adjacency matrix where the i,jth + """Returns adjacency matrix. + + The i,jth entry being one indicates that there is an edge from i to j. A zero indicates that there is no edge. 
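A quick sketch of the PDAG container in use; the three-node graph is made up for illustration:

from causalAssembly.pdag import PDAG

pdag = PDAG(nodes=["A", "B", "C"], dir_edges=[("A", "B")], undir_edges=[("B", "C")])
assert pdag.is_adjacent("A", "B")
assert "C" in pdag.undir_neighbors("B")

pdag.undir_to_dir_edge(tail="B", head="C")  # orient B - C as B -> C
print(pdag.adjacency_matrix)                # pandas amat; amat.loc[i, j] == 1 encodes i -> j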
@@ -343,7 +356,9 @@ def adjacency_matrix(self) -> pd.DataFrame: return amat def _amat_to_dag(self) -> pd.DataFrame: - """Transform the adjacency matrix of an PDAG to the adjacency + """Adjacency matrix to random DAG. + + Transform the adjacency matrix of an PDAG to the adjacency matrix of a SOME DAG in the Markov equivalence class. Returns: @@ -375,7 +390,7 @@ def _amat_to_dag(self) -> pd.DataFrame: ) def vstructs(self) -> set: - """Retrieve v-structures + """Retrieve v-structures. Returns: set: set of all v-structures @@ -389,8 +404,8 @@ def vstructs(self) -> set: return vstructures def copy(self): - """Return a copy of the graph""" - return PDAG(nodes=self._nodes, dir_edges=self._dir_edges, undir_edges=self._undir_edges) + """Return a copy of the graph.""" + return PDAG(nodes=self._nodes, dir_edges=self._dir_edges, undir_edges=self._undir_edges) # type: ignore def show(self): """Plot PDAG.""" @@ -414,8 +429,8 @@ def to_networkx(self) -> nx.MultiDiGraph: return nx_pdag def _meek_mec_enumeration(self, pdag: PDAG, dag_list: list): - """Recursion algorithm which recursively applies the - following steps: + """Recursion algorithm which recursively applies the following steps. + 1. Orient the first undirected edge found. 2. Apply Meek rules. 3. Recurse with each direction of the oriented edge. @@ -455,8 +470,8 @@ def _meek_mec_enumeration(self, pdag: PDAG, dag_list: list): self._meek_mec_enumeration(pdag=g_copy, dag_list=dag_list) def to_allDAGs(self) -> list[nx.DiGraph]: - """Recursion algorithm which recursively applies the - following steps: + """Recursion algorithm which recursively applies the following steps. + 1. Orient the first undirected edge found. 2. Apply Meek rules. 3. Recurse with each direction of the oriented edge. @@ -473,8 +488,7 @@ def to_allDAGs(self) -> list[nx.DiGraph]: # use Meek's cpdag2alldag def _apply_meek_rules(self, G: PDAG) -> PDAG: - """Apply all four Meek rules to a - PDAG turning it into a CPDAG. + """Apply all four Meek rules to a PDAG turning it into a CPDAG. Args: G (PDAG): PDAG to complete @@ -511,13 +525,13 @@ def to_random_dag(self) -> nx.DiGraph: return nx.from_pandas_adjacency(to_dag_candidate.adjacency_matrix, create_using=nx.DiGraph) @property - def nodes(self) -> list: + def nodes(self) -> list[str] | list[int]: """Get all nods in current PDAG. Returns: list: list of nodes. """ - return sorted(list(self._nodes)) + return sorted(list(self._nodes)) # type: ignore @property def nnodes(self) -> int: @@ -530,8 +544,7 @@ def nnodes(self) -> int: @property def num_undir_edges(self) -> int: - """Number of undirected edges - in current PDAG. + """Number of undirected edges in current PDAG. Returns: int: Number of undirected edges @@ -540,8 +553,7 @@ def num_undir_edges(self) -> int: @property def num_dir_edges(self) -> int: - """Number of directed edges - in current PDAG. + """Number of directed edges in current PDAG. Returns: int: Number of directed edges @@ -550,8 +562,7 @@ def num_dir_edges(self) -> int: @property def num_adjacencies(self) -> int: - """Number of adjacent nodes - in current PDAG. + """Number of adjacent nodes in current PDAG. Returns: int: Number of adjacent nodes @@ -560,8 +571,7 @@ def num_adjacencies(self) -> int: @property def undir_edges(self) -> list[tuple]: - """Gives all undirected edges in - current PDAG. + """Gives all undirected edges in current PDAG. Returns: list[tuple]: List of undirected edges. 
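And a short sketch of the MEC enumeration entry points touched in the hunks above; the undirected chain is illustrative:

import networkx as nx

from causalAssembly.pdag import PDAG

# The chain A - B - C has no v-structure, so its Markov equivalence
# class should contain exactly three DAGs.
chain = PDAG(nodes=["A", "B", "C"], undir_edges=[("A", "B"), ("B", "C")])
member_dags = chain.to_allDAGs()  # recursive edge orientation plus Meek rules
assert len(member_dags) == 3
assert all(isinstance(g, nx.DiGraph) for g in member_dags)

one_dag = chain.to_random_dag()   # or draw a single consistent extension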
@@ -570,8 +580,7 @@

    @property
    def dir_edges(self) -> list[tuple]:
-        """Gives all directed edges in
-        current PDAG.
+        """Gives all directed edges in current PDAG.

        Returns:
            list[tuple]: List of directed edges.
@@ -598,7 +607,9 @@ def vstructs(dag: nx.DiGraph) -> set:


 def rule_1(pdag: PDAG) -> PDAG:
-    """Given the following pattern X -> Y - Z. Orient Y - Z to Y -> Z
+    """Meek's first rule.
+
+    Given the following pattern X -> Y - Z. Orient Y - Z to Y -> Z
    if X and Z are non-adjacent (otherwise a new v-structure arises).

    Args:
@@ -625,7 +636,9 @@


 def rule_2(pdag: PDAG) -> PDAG:
-    """Given the following directed triple
+    """Meek's second rule.
+
+    Given the following directed triple
    X -> Y -> Z where X - Z are indeed adjacent.
    Orient X - Z to X -> Z otherwise a cycle arises.
@@ -653,7 +666,9 @@


 def rule_3(pdag: PDAG) -> PDAG:
-    """Orient X - Z to X -> Z, whenever there are two triples
+    """Meek's third rule.
+
+    Orient X - Z to X -> Z, whenever there are two triples
    X - Y1 -> Z and X - Y2 -> Z such that Y1 and Y2 are non-adjacent.

    Args:
@@ -662,6 +677,7 @@
    Returns:
        PDAG: PDAG after application of rule.
    """
+    TWO = 2
    copy_pdag = pdag.copy()
    for edge in copy_pdag.undir_edges:
        reverse_edge = edge[::-1]
@@ -670,7 +686,7 @@
            # if true that tail - node1 -> head and tail - node2 -> head
            # while {node1 U node2} = 0 then orient tail -> head
            orient = False
-            if len(copy_pdag.undir_neighbors(tail)) >= 2:
+            if len(copy_pdag.undir_neighbors(tail)) >= TWO:
                undir_n = copy_pdag.undir_neighbors(tail)
                selection = [
                    (node1, node2)
@@ -688,7 +704,9 @@


 def rule_4(pdag: PDAG) -> PDAG:
-    """Orient X - Y1 to X -> Y1, whenever there are
+    """Meek's fourth rule.
+
+    Orient X - Y1 to X -> Y1, whenever there are
    two triples with X - Z and X - Y1 <- Z
    and X - Y2 -> Z such that Y1 and Y2 are non-adjacent.

@@ -720,7 +738,7 @@


 def dag2cpdag(dag: nx.DiGraph) -> PDAG:
-    """Convertes a DAG into its unique CPDAG
+    """Converts a DAG into its unique CPDAG.

    Args:
        dag (nx.DiGraph): DAG the CPDAG corresponds to.
@@ -728,7 +746,7 @@ def dag2cpdag(dag: nx.DiGraph) -> PDAG: Returns: PDAG: unique CPDAG """ - copy_dag = dag.copy() + copy_dag: nx.DiGraph = dag.copy() # type: ignore # Skeleton skeleton = nx.to_pandas_adjacency(copy_dag.to_undirected()) # v-Structures diff --git a/licenses.txt b/licenses.txt new file mode 100644 index 0000000..94610af --- /dev/null +++ b/licenses.txt @@ -0,0 +1,76 @@ + Name Version License + ghp-import 2.1.0 Apache Software License + importlib_metadata 8.7.0 Apache Software License + importlib_resources 6.5.2 Apache Software License + requests 2.32.4 Apache Software License + tzdata 2025.2 Apache Software License + watchdog 6.0.0 Apache Software License + packaging 25.0 Apache Software License; BSD License + python-dateutil 2.9.0.post0 Apache Software License; BSD License + verspec 0.1.0 Apache Software License; BSD License + uv 0.8.4 Apache Software License; MIT License + coverage 7.10.2 Apache-2.0 + Jinja2 3.1.6 BSD License + MarkupSafe 3.0.2 BSD License + Pygments 2.19.2 BSD License + babel 2.17.0 BSD License + colorama 0.4.6 BSD License + contourpy 1.3.2 BSD License + cycler 0.12.1 BSD License + idna 3.10 BSD License + kiwisolver 1.4.8 BSD License + mike 2.1.3 BSD License + mkdocs 1.6.1 BSD License + mpmath 1.3.0 BSD License + networkx 3.4.2 BSD License + nodeenv 1.9.1 BSD License + numpy 2.2.6 BSD License + pandas 2.3.1 BSD License + pip-tools 7.5.0 BSD License + pycparser 2.22 BSD License + scipy 1.15.3 BSD License + sympy 1.14.0 BSD License + rpy2-robjects 3.6.1 GNU General Public License v2 or later (GPLv2+) + rpy2-rinterface 3.6.2 GPL-2.0-or-later + fonttools 4.59.0 MIT + identify 2.6.12 MIT + pytest-cov 6.2.1 MIT + PyYAML 6.0.2 MIT License + backrefs 5.9 MIT License + cffi 1.17.1 MIT License + cfgv 3.4.0 MIT License + charset-normalizer 3.4.2 MIT License + exceptiongroup 1.3.0 MIT License + iniconfig 2.1.0 MIT License + mergedeep 1.3.4 MIT License + mkdocs-get-deps 0.2.0 MIT License + mkdocs-material 9.6.16 MIT License + mkdocs-material-extensions 1.3.1 MIT License + paginate 0.5.7 MIT License + platformdirs 4.3.8 MIT License + pluggy 1.6.0 MIT License + pre_commit 4.2.0 MIT License + pymdown-extensions 10.16.1 MIT License + pyparsing 3.2.3 MIT License + pyproject_hooks 1.2.0 MIT License + pytest 8.4.1 MIT License + pytz 2025.2 MIT License + ruff 0.12.7 MIT License + six 1.17.0 MIT License + tzlocal 5.3.1 MIT License + virtualenv 20.33.0 MIT License + distlib 0.4.0 Python Software Foundation License + matplotlib 3.10.5 Python Software Foundation License + filelock 3.18.0 The Unlicense (Unlicense) + Markdown 3.8.2 UNKNOWN + build 1.3.0 UNKNOWN + click 8.2.1 UNKNOWN + griffe 1.9.0 UNKNOWN + mkdocs-autorefs 1.4.2 UNKNOWN + mkdocstrings 0.30.0 UNKNOWN + mkdocstrings-python 1.16.12 UNKNOWN + pillow 11.3.0 UNKNOWN + pyyaml_env_tag 1.1 UNKNOWN + typing_extensions 4.14.1 UNKNOWN + urllib3 2.5.0 UNKNOWN + zipp 3.23.0 UNKNOWN diff --git a/pyproject.toml b/pyproject.toml index 98be8c5..8e7c1e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "requests", ] readme = "README.md" -license = {file = "LICENSE"} +license = { file = "LICENSE" } [project.urls] Homepage = "https://github.com/boschresearch/causalAssembly" @@ -38,10 +38,13 @@ dev = [ "mkdocs", "mkdocs-material", "mkdocstrings[python]", - "ruff", + "pip-licenses", "pip-tools", "pre-commit", "pytest", + "pytest-cov", + "ruff", + "uv", ] [tool.setuptools.packages.find] @@ -54,35 +57,62 @@ version = { file = "VERSION" } [tool.ruff] -select = ["A", "E", "F", "I"] -ignore = [] - -fixable = 
["A", "B", "C", "D", "E", "F", "I"]
-unfixable = []
-
-line-length = 100
-
-target-version = "py310"
-
 exclude = [
-    ".bzr",
-    ".direnv",
-    ".eggs",
-    ".git",
-    ".hg",
+    ".github",
     ".mypy_cache",
-    ".nox",
-    ".pants.d",
+    ".pytest_cache",
     ".ruff_cache",
-    ".svn",
-    ".tox",
     ".venv",
-    "__pypackages__",
-    "_build",
-    "buck-out",
-    "build",
-    "dist",
-    "node_modules",
     "venv",
+    #"notebooks/",
+]
+
+extend-include = ["*.ipynb"]
+
+line-length = 100
+
+[tool.ruff.lint]
+select = [
+    "E", # pycodestyle
+    "F", # pyflakes
+    "UP", # pyupgrade
+    "D", # pydocstyle
+    "PL", # pylint
+    "TD", # flake8-todos
+    "C90", # McCabe
 ]
-per-file-ignores = {}
+
+ignore = []
+
+# Allow fix for all enabled rules (when `--fix`) is provided.
+fixable = ["ALL"]
+unfixable = []
+
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
+[tool.ruff.lint.mccabe]
+max-complexity = 27
+
+[tool.ruff.lint.pylint]
+max-args = 15
+max-branches = 30
+max-statements = 100
+
+[tool.pytest.ini_options]
+addopts = "--cov=causalAssembly --cov-fail-under=60"
+testpaths = ["tests"]
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a788f3f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,33 @@
+# This file was autogenerated by uv via the following command:
+#    uv pip compile pyproject.toml --output-file=requirements.txt --annotation-style=line
+certifi==2025.8.3 # via requests
+cffi==1.17.1 # via rpy2-rinterface
+charset-normalizer==3.4.2 # via requests
+contourpy==1.3.2 # via matplotlib
+cycler==0.12.1 # via matplotlib
+fonttools==4.59.0 # via matplotlib
+idna==3.10 # via requests
+jinja2==3.1.6 # via rpy2-robjects
+kiwisolver==1.4.8 # via matplotlib
+markupsafe==3.0.2 # via jinja2
+matplotlib==3.10.5 # via causalassembly (pyproject.toml)
+mpmath==1.3.0 # via sympy
+networkx==3.4.2 # via causalassembly (pyproject.toml)
+numpy==2.2.6 # via contourpy, matplotlib, pandas, scipy, causalassembly (pyproject.toml)
+packaging==25.0 # via matplotlib
+pandas==2.3.1 # via causalassembly (pyproject.toml)
+pillow==11.3.0 # via matplotlib
+pycparser==2.22 # via cffi
+pyparsing==3.2.3 # via matplotlib
+python-dateutil==2.9.0.post0 # via matplotlib, pandas
+pytz==2025.2 # via pandas
+requests==2.32.4 # via causalassembly (pyproject.toml)
+rpy2==3.6.2 # via causalassembly (pyproject.toml)
+rpy2-rinterface==3.6.2 # via rpy2, rpy2-robjects
+rpy2-robjects==3.6.1 # via rpy2
+scipy==1.15.3 # via causalassembly (pyproject.toml)
+six==1.17.0 # via python-dateutil
+sympy==1.14.0 # via causalassembly (pyproject.toml)
+tzdata==2025.2 # via pandas
+tzlocal==5.3.1 # via rpy2-robjects
+urllib3==2.5.0 # via requests
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 3666155..3a26ad6 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -4,80 +4,90 @@
 #
 #    pip-compile --allow-unsafe --annotation-style=line --extra=dev --no-emit-index-url --no-emit-trusted-host --output-file=requirements_dev.txt pyproject.toml
 #
-babel==2.14.0 # via mkdocs-material
-build==1.0.3 # via pip-tools
-certifi==2024.2.2 # via requests
-cffi==1.16.0 # via rpy2
+babel==2.17.0 # via mkdocs-material
+backrefs==5.9 # via mkdocs-material
+build==1.3.0 # via pip-tools
+certifi==2025.8.3 # via requests
+cffi==1.17.1 # 
via rpy2-rinterface cfgv==3.4.0 # via pre-commit -charset-normalizer==3.3.2 # via requests -click==8.1.7 # via mkdocs, mkdocstrings, pip-tools +charset-normalizer==3.4.2 # via requests +click==8.2.1 # via mkdocs, pip-tools colorama==0.4.6 # via griffe, mkdocs-material -contourpy==1.2.0 # via matplotlib +contourpy==1.3.2 # via matplotlib +coverage[toml]==7.10.2 # via coverage, pytest-cov cycler==0.12.1 # via matplotlib -distlib==0.3.8 # via virtualenv -exceptiongroup==1.2.0 # via pytest -filelock==3.13.1 # via virtualenv -fonttools==4.48.1 # via matplotlib +distlib==0.4.0 # via virtualenv +exceptiongroup==1.3.0 # via pytest +filelock==3.18.0 # via virtualenv +fonttools==4.59.0 # via matplotlib ghp-import==2.1.0 # via mkdocs -griffe==0.40.1 # via mkdocstrings-python -identify==2.5.34 # via pre-commit -idna==3.6 # via requests -importlib-metadata==7.0.1 # via mike -importlib-resources==6.1.1 # via mike -iniconfig==2.0.0 # via pytest -jinja2==3.1.3 # via mike, mkdocs, mkdocs-material, mkdocstrings, rpy2 -kiwisolver==1.4.5 # via matplotlib -markdown==3.5.2 # via mkdocs, mkdocs-autorefs, mkdocs-material, mkdocstrings, pymdown-extensions -markupsafe==2.1.5 # via jinja2, mkdocs, mkdocstrings -matplotlib==3.8.2 # via causalAssembly (pyproject.toml) -mergedeep==1.3.4 # via mkdocs -mike==2.0.0 # via causalAssembly (pyproject.toml) -mkdocs==1.5.3 # via causalAssembly (pyproject.toml), mike, mkdocs-autorefs, mkdocs-material, mkdocstrings -mkdocs-autorefs==0.5.0 # via mkdocstrings -mkdocs-material==9.5.9 # via causalAssembly (pyproject.toml) +griffe==1.9.0 # via mkdocstrings-python +identify==2.6.12 # via pre-commit +idna==3.10 # via requests +importlib-metadata==8.7.0 # via mike +importlib-resources==6.5.2 # via mike +iniconfig==2.1.0 # via pytest +jinja2==3.1.6 # via mike, mkdocs, mkdocs-material, mkdocstrings, rpy2-robjects +kiwisolver==1.4.8 # via matplotlib +markdown==3.8.2 # via mkdocs, mkdocs-autorefs, mkdocs-material, mkdocstrings, pymdown-extensions +markupsafe==3.0.2 # via jinja2, mkdocs, mkdocs-autorefs, mkdocstrings +matplotlib==3.10.5 # via causalAssembly (pyproject.toml) +mergedeep==1.3.4 # via mkdocs, mkdocs-get-deps +mike==2.1.3 # via causalAssembly (pyproject.toml) +mkdocs==1.6.1 # via causalAssembly (pyproject.toml), mike, mkdocs-autorefs, mkdocs-material, mkdocstrings +mkdocs-autorefs==1.4.2 # via mkdocstrings, mkdocstrings-python +mkdocs-get-deps==0.2.0 # via mkdocs +mkdocs-material==9.6.16 # via causalAssembly (pyproject.toml) mkdocs-material-extensions==1.3.1 # via mkdocs-material -mkdocstrings[python]==0.24.0 # via causalAssembly (pyproject.toml), mkdocstrings-python -mkdocstrings-python==1.8.0 # via mkdocstrings +mkdocstrings[python]==0.30.0 # via causalAssembly (pyproject.toml), mkdocstrings-python +mkdocstrings-python==1.16.12 # via mkdocstrings mpmath==1.3.0 # via sympy -networkx==3.2.1 # via causalAssembly (pyproject.toml) -nodeenv==1.8.0 # via pre-commit -numpy==1.26.4 # via causalAssembly (pyproject.toml), contourpy, matplotlib, pandas, scipy -packaging==23.2 # via build, matplotlib, mkdocs, pytest -paginate==0.5.6 # via mkdocs-material -pandas==2.2.0 # via causalAssembly (pyproject.toml) +networkx==3.4.2 # via causalAssembly (pyproject.toml) +nodeenv==1.9.1 # via pre-commit +numpy==2.2.6 # via causalAssembly (pyproject.toml), contourpy, matplotlib, pandas, scipy +packaging==25.0 # via build, matplotlib, mkdocs, pytest +paginate==0.5.7 # via mkdocs-material +pandas==2.3.1 # via causalAssembly (pyproject.toml) pathspec==0.12.1 # via mkdocs -pillow==10.2.0 # via matplotlib 
-pip-tools==7.3.0 # via causalAssembly (pyproject.toml) -platformdirs==4.2.0 # via mkdocs, mkdocstrings, virtualenv -pluggy==1.4.0 # via pytest -pre-commit==3.6.1 # via causalAssembly (pyproject.toml) -pycparser==2.21 # via cffi -pygments==2.17.2 # via mkdocs-material -pymdown-extensions==10.7 # via mkdocs-material, mkdocstrings -pyparsing==3.1.1 # via matplotlib, mike -pyproject-hooks==1.0.0 # via build -pytest==8.0.0 # via causalAssembly (pyproject.toml) -python-dateutil==2.8.2 # via ghp-import, matplotlib, pandas -pytz==2024.1 # via pandas -pyyaml==6.0.1 # via mike, mkdocs, pre-commit, pymdown-extensions, pyyaml-env-tag -pyyaml-env-tag==0.1 # via mkdocs -regex==2023.12.25 # via mkdocs-material -requests==2.31.0 # via causalAssembly (pyproject.toml), mkdocs-material -rpy2==3.5.15 # via causalAssembly (pyproject.toml) -ruff==0.2.1 # via causalAssembly (pyproject.toml) -scipy==1.12.0 # via causalAssembly (pyproject.toml) -six==1.16.0 # via python-dateutil -sympy==1.12 # via causalAssembly (pyproject.toml) -tomli==2.0.1 # via build, pip-tools, pyproject-hooks, pytest -tzdata==2024.1 # via pandas -tzlocal==5.2 # via rpy2 -urllib3==2.2.0 # via requests +pillow==11.3.0 # via matplotlib +pip-licenses==5.0.0 # via causalAssembly (pyproject.toml) +pip-tools==7.5.0 # via causalAssembly (pyproject.toml) +platformdirs==4.3.8 # via mkdocs-get-deps, virtualenv +pluggy==1.6.0 # via pytest, pytest-cov +pre-commit==4.2.0 # via causalAssembly (pyproject.toml) +prettytable==3.16.0 # via pip-licenses +pycparser==2.22 # via cffi +pygments==2.19.2 # via mkdocs-material, pytest +pymdown-extensions==10.16.1 # via mkdocs-material, mkdocstrings +pyparsing==3.2.3 # via matplotlib, mike +pyproject-hooks==1.2.0 # via build, pip-tools +pytest==8.4.1 # via causalAssembly (pyproject.toml), pytest-cov +pytest-cov==6.2.1 # via causalAssembly (pyproject.toml) +python-dateutil==2.9.0.post0 # via ghp-import, matplotlib, pandas +pytz==2025.2 # via pandas +pyyaml==6.0.2 # via mike, mkdocs, mkdocs-get-deps, pre-commit, pymdown-extensions, pyyaml-env-tag +pyyaml-env-tag==1.1 # via mike, mkdocs +requests==2.32.4 # via causalAssembly (pyproject.toml), mkdocs-material +rpy2==3.6.2 # via causalAssembly (pyproject.toml) +rpy2-rinterface==3.6.2 # via rpy2, rpy2-robjects +rpy2-robjects==3.6.1 # via rpy2 +ruff==0.12.7 # via causalAssembly (pyproject.toml) +scipy==1.15.3 # via causalAssembly (pyproject.toml) +six==1.17.0 # via python-dateutil +sympy==1.14.0 # via causalAssembly (pyproject.toml) +tomli==2.2.1 # via build, coverage, pip-licenses, pip-tools, pytest +typing-extensions==4.14.1 # via exceptiongroup, mkdocstrings-python +tzdata==2025.2 # via pandas +tzlocal==5.3.1 # via rpy2-robjects +urllib3==2.5.0 # via requests +uv==0.8.4 # via causalAssembly (pyproject.toml) verspec==0.1.0 # via mike -virtualenv==20.25.0 # via pre-commit -watchdog==4.0.0 # via mkdocs -wheel==0.42.0 # via pip-tools -zipp==3.17.0 # via importlib-metadata +virtualenv==20.33.0 # via pre-commit +watchdog==6.0.0 # via mkdocs +wcwidth==0.2.13 # via prettytable +wheel==0.45.1 # via pip-tools +zipp==3.23.0 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: -pip==24.0 # via pip-tools -setuptools==69.1.0 # via nodeenv, pip-tools +pip==25.2 # via pip-tools +setuptools==80.9.0 # via pip-tools diff --git a/tests/test_dag.py b/tests/test_dag.py index f5586de..c53f065 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. 
+"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -12,6 +13,7 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ + import networkx as nx import numpy as np import pandas as pd @@ -22,29 +24,62 @@ class TestDAG: + """Test class for the DAG class. + + Returns: + _type_: _description_ + """ + @pytest.fixture(scope="class") def example_dag(self) -> DAG: + """Example dag. + + Returns: + DAG: _description_ + """ return DAG(nodes=["A", "B", "C"], edges=[("A", "B"), ("A", "C")]) def test_instance_is_created(self): + """Check whether an instance of DAG can be created with nodes only.""" dag = DAG(nodes=["A", "B", "C"]) assert isinstance(dag, DAG) def test_edges(self, example_dag: DAG): - assert example_dag.num_edges == 2 + """Test edges of the DAG. + + Args: + example_dag (DAG): _description_ + """ + TWO = 2 + assert example_dag.num_edges == TWO assert set(example_dag.edges) == {("A", "B"), ("A", "C")} def test_children(self, example_dag: DAG): + """Test children of the DAG. + + Args: + example_dag (DAG): _description_ + """ assert set(example_dag.children(of_node="A")) == {"B", "C"} assert example_dag.children(of_node="B") == [] assert example_dag.children(of_node="C") == [] def test_parents(self, example_dag: DAG): + """TEst parents of the DAG. + + Args: + example_dag (DAG): _description_ + """ assert example_dag.parents(of_node="A") == [] assert set(example_dag.parents(of_node="B")) == {"A"} assert set(example_dag.parents(of_node="C")) == {"A"} def test_from_pandas_adjacency(self, example_dag: DAG): + """Test import from pandas adjacency matrix. + + Args: + example_dag (DAG): _description_ + """ amat = pd.DataFrame( [[0, 1, 1], [0, 0, 0], [0, 0, 0]], columns=["A", "B", "C"], @@ -55,6 +90,11 @@ def test_from_pandas_adjacency(self, example_dag: DAG): assert set(from_pandas_pdag.edges) == set(example_dag.edges) def test_remove_edge(self, example_dag: DAG): + """Test removing edges. + + Args: + example_dag (DAG): _description_ + """ assert ("A", "C") in example_dag.edges example_dag.remove_edge("A", "C") assert ("A", "C") not in example_dag.edges @@ -62,34 +102,49 @@ def test_remove_edge(self, example_dag: DAG): example_dag.remove_edge("B", "A") def test_remove_node(self, example_dag: DAG): + """Test removing nodes. + + Args: + example_dag (DAG): _description_ + """ assert "C" in example_dag.nodes example_dag.remove_node("C") assert "C" not in example_dag.nodes def test_to_cpdag(self): + """Test conversion to CPDAG.""" + TWO = 2 dag = DAG() dag.add_edges_from([("A", "B"), ("A", "C")]) cpdag = dag.to_cpdag() assert isinstance(cpdag, PDAG) - assert cpdag.num_undir_edges == 2 + assert cpdag.num_undir_edges == TWO assert cpdag.num_dir_edges == 0 def test_adjacency_matrix(self, example_dag: DAG): + """Test return of adjacency matrix. 
+ + Args: + example_dag (DAG): _description_ + """ amat = example_dag.adjacency_matrix assert amat.shape[0] == amat.shape[1] == example_dag.num_nodes assert amat.sum().sum() == example_dag.num_edges def test_to_networkx(self, example_dag: DAG): + """Test conversion to NetworkX graph.""" nxg = example_dag.to_networkx() assert isinstance(nxg, nx.DiGraph) assert set(nxg.edges) == set(example_dag.edges) def test_from_networkx(self, example_dag: DAG): + """Test conversion from NetworkX graph to DAG.""" nxg = example_dag.to_networkx() from_networkx_dag = DAG.from_nx(nxg) assert set(from_networkx_dag.edges) == set(example_dag.edges) def test_error_when_cyclic(self): + """Test that an error is raised when trying to create a cyclic DAG.""" dag = DAG() with pytest.raises(ValueError): dag.add_edges_from([("A", "C"), ("C", "D"), ("D", "A")]) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 3533340..6be09e1 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -12,6 +13,7 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ + import random import string @@ -24,8 +26,19 @@ class TestDAGmetrics: + """Test metrics class. + + Returns: + _type_: _description_ + """ + @pytest.fixture(scope="class") def gt(self): + """Set up ground truth adjacency matrix. + + Returns: + _type_: _description_ + """ names = list(string.ascii_lowercase)[0:5] temp_np = np.array( [ @@ -40,6 +53,11 @@ def gt(self): @pytest.fixture(scope="class") def est(self): + """Set up estimated adjacency matrix. + + Returns: + _type_: _description_ + """ names = list(string.ascii_lowercase)[0:5] temp_np = np.array( [ @@ -53,11 +71,23 @@ def est(self): return pd.DataFrame(temp_np, columns=names, index=names) def test_pd_input_works(self, gt, est): + """Test that the metrics class can be initialized with pandas DataFrames. + + Args: + gt (_type_): _description_ + est (_type_): _description_ + """ met = DAGmetrics(truth=gt, est=est) assert np.array_equal(met.truth, gt.to_numpy()) assert np.array_equal(met.est, est.to_numpy()) def test_nx_input_works(self, gt, est): + """Test that the metrics class can be initialized with networkx graphs. + + Args: + gt (_type_): _description_ + est (_type_): _description_ + """ gt_nx = nx.from_pandas_adjacency(gt, create_using=nx.DiGraph) est_nx = nx.from_pandas_adjacency(est, create_using=nx.DiGraph) met = DAGmetrics(truth=gt_nx, est=est_nx) @@ -65,6 +95,12 @@ def test_nx_input_works(self, gt, est): assert np.array_equal(met.est, est.to_numpy()) def test_pd_change_order(self, gt, est): + """Test that the metrics class can handle different node orders in input DataFrames. + + Args: + gt (_type_): _description_ + est (_type_): _description_ + """ nodelist = list(string.ascii_lowercase)[0:5] random.shuffle(nodelist) met = DAGmetrics(truth=gt, est=est, nodelist=nodelist) @@ -72,6 +108,12 @@ def test_pd_change_order(self, gt, est): assert np.array_equal(met.est, est.reindex(nodelist)[nodelist].to_numpy()) def test_nx_change_order(self, gt, est): + """Test that the metrics class can handle different node orders in networkx graphs. 
+ + Args: + gt (_type_): _description_ + est (_type_): _description_ + """ nodelist = list(string.ascii_lowercase)[0:5] random.shuffle(nodelist) @@ -83,11 +125,18 @@ def test_nx_change_order(self, gt, est): assert np.array_equal(met.est, est.reindex(nodelist)[nodelist].to_numpy()) def test_metrics_values(self, gt, est): + """Test the metrics values. + + Args: + gt (_type_): _description_ + est (_type_): _description_ + """ + THREE = 3 met = DAGmetrics(truth=gt, est=est) - met.collect_metrics() + met = met.collect_metrics() - assert met.metrics["shd"] == 3 - assert met.metrics["gscore"] >= 0 and met.metrics["gscore"] <= 1 - assert met.metrics["f1"] >= 0 and met.metrics["f1"] <= 1 - assert met.metrics["recall"] >= 0 and met.metrics["recall"] <= 1 - assert met.metrics["precision"] >= 0 and met.metrics["precision"] <= 1 + assert met["shd"] == THREE + assert met["gscore"] >= 0 and met["gscore"] <= 1 + assert met["f1"] >= 0 and met["f1"] <= 1 + assert met["recall"] >= 0 and met["recall"] <= 1 + assert met["precision"] >= 0 and met["precision"] <= 1 diff --git a/tests/test_models_dag.py b/tests/test_models_dag.py index 9d3c7c2..73bc020 100644 --- a/tests/test_models_dag.py +++ b/tests/test_models_dag.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -12,6 +13,7 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ + import math import os import pickle @@ -20,29 +22,46 @@ import numpy as np import pandas as pd import pytest -from sympy.stats import Beta from causalAssembly.models_dag import NodeAttributes, ProcessCell, ProductionLineGraph class TestProcessCell: + """Test process Cell. + + Returns: + _type_: _description_ + """ + @pytest.fixture(scope="class") def cell(self): + """Set up a ProcessCell instance. + + Returns: + _type_: _description_ + """ c = ProcessCell(name="PYTEST") return c @pytest.fixture(scope="class") def module(self): + """Set up a module for testing. + + Returns: + _type_: _description_ + """ m = nx.DiGraph() m.add_nodes_from(["A", "B", "C"]) m.add_edges_from([("A", "B"), ("B", "C")]) return m def test_instance_is_created(self): + """Test whether an instance of ProcessCell can be created with a name.""" cell = ProcessCell(name="PYTEST") assert isinstance(cell, ProcessCell) def test_next_module_prefix_works(self): + """Test whether the next module prefix is generated correctly.""" # Arrange cell = ProcessCell(name="PYTEST") @@ -55,10 +74,18 @@ def test_next_module_prefix_works(self): assert cell.next_module_prefix() == "ABC1" def test_module_prefix_setter_works(self, cell): + """TEst that the module prefix can be set correctly. + + Args: + cell (_type_): _description_ + """ with pytest.raises(ValueError): cell.module_prefix = 1 def test_add_module_works(self, module): + """Test that a module can be added to the ProcessCell.""" + TWO = 2 + SIX = 6 # Arrange cell = ProcessCell(name="C01") cell.module_prefix = "M" @@ -68,11 +95,16 @@ def test_add_module_works(self, module): cell.add_module(module) # Assert - assert len(cell.modules) == 2 - assert len(cell.graph.nodes()) == 6 + assert len(cell.modules) == TWO + assert len(cell.graph.nodes()) == SIX assert cell.next_module_prefix() == "M3" def test_connect_by_module_works(self, module): + """Test that modules can be connected by edges. 
+ + Args: + module (_type_): _description_ + """ # Arrange cell = ProcessCell(name="PyTestCell") m1 = cell.add_module(graph=module) @@ -92,6 +124,12 @@ def test_connect_by_module_works(self, module): ids=["source node invalid", "target node invalid"], ) def test_connect_by_module_fails_with_wrong_node_name(self, module, edges): + """Test that module connections fails when node names are invalid. + + Args: + module (_type_): _description_ + edges (_type_): _description_ + """ # Arrange cell = ProcessCell(name="PyTestCell") @@ -104,6 +142,11 @@ def test_connect_by_module_fails_with_wrong_node_name(self, module, edges): cell.connect_by_module(m1=m1, m2=m2, edges=edges) def test_node_property(self, module): + """Test properties of the nodes. + + Args: + module (_type_): _description_ + """ # Arrange cell = ProcessCell(name="PyTest") @@ -116,6 +159,11 @@ def test_node_property(self, module): assert len(cell.nodes) == expected_no_of_nodes def test_repr_is_working(self, module): + """Test repr is working. + + Args: + module (_type_): _description_ + """ # Arrange cell = ProcessCell(name="PyTest") @@ -131,6 +179,11 @@ def test_repr_is_working(self, module): ids=["sparsity=0.0", "sparsity=0.1", "sparsity=1.0"], ) def test_connect_by_random_edges(self, sparsity): + """Test whether connecting with random edges works. + + Args: + sparsity (_type_): _description_ + """ pline = ProductionLineGraph() pline.new_cell(name="C1") # Arrange @@ -149,6 +202,7 @@ def test_connect_by_random_edges(self, sparsity): assert len(c.graph.edges) == expected_edges def test_connect_by_random_edges_fails_with_cyclic_graph(self): + """Test failure with cyclic graphs.""" pline = ProductionLineGraph() pline.new_cell(name="C1") # Arrange @@ -166,6 +220,11 @@ def test_connect_by_random_edges_fails_with_cyclic_graph(self): c.connect_by_random_edges() def test_get_nodes_by_attribute(self, module): + """Test get nodes by attribute. 
+ + Args: + module (_type_): _description_ + """ c = ProcessCell(name="C1") c.add_module(graph=module) @@ -174,15 +233,18 @@ def test_get_nodes_by_attribute(self, module): assert NodeAttributes.ALLOW_IN_EDGES in available_attributes def test_input_cellgraph_directly_works(self): + """Test whether cellgraph is inputted correctly.""" + THREE = 3 toygraph = nx.DiGraph() toygraph.add_edges_from([("a", "b"), ("a", "c"), ("b", "c")]) c = ProcessCell(name="toycell") c.input_cellgraph_directly(toygraph) - assert len(c.nodes) == 3 + assert len(c.nodes) == THREE assert c.nodes == ["toycell_a", "toycell_b", "toycell_c"] def test_ground_truth_cell(self): + """Test ground truth.""" pline = ProductionLineGraph() pline.new_cell(name="test") pline.test.add_random_module() @@ -195,11 +257,15 @@ def test_ground_truth_cell(self): class TestProductionLineGraph: + """Test ProductionLineGraph.""" + def test_instance_is_created(self): + """Test whether instance is created.""" p = ProductionLineGraph() assert isinstance(p, ProductionLineGraph) def test_getattr_works(self): + """Test getattr.""" # Arrange station_name = "Station1" p = ProductionLineGraph() @@ -211,6 +277,7 @@ def test_getattr_works(self): p.XXX def test_str_representation(self): + """Test str.""" p = ProductionLineGraph() p.new_cell(name="C1") p.C1.add_random_module(n_nodes=10) @@ -218,6 +285,7 @@ def test_str_representation(self): assert isinstance(str(p), str) def test_create_cell_works(self): + """Test cell creation.""" p = ProductionLineGraph() c1 = p.new_cell() @@ -225,12 +293,14 @@ def test_create_cell_works(self): assert c1.name == "C0" def test_create_cell_with_name_works(self): + """Test cell with name.""" p = ProductionLineGraph() c1 = p.new_cell(name="PyTest") assert c1.name == "PyTest" def test_append_same_cell_twice_fails(self): + """Test failure.""" p = ProductionLineGraph() p.new_cell(name="PyTest") @@ -238,18 +308,21 @@ def test_append_same_cell_twice_fails(self): p.new_cell(name="PyTest") def test_instance_via_cell_number_works(self): + """Test instance via cell number.""" n_cells = 10 p = ProductionLineGraph.via_cell_number(n_cells=n_cells) assert len(p.cells) == n_cells def test_if_graph_exists(self): + """Test existence.""" n_cells = 10 p = ProductionLineGraph.via_cell_number(n_cells=n_cells) assert isinstance(p.graph, nx.DiGraph) def test_add_eol_cell(self): + """Test eol cell.""" p = ProductionLineGraph() p.new_cell() p.new_cell(is_eol=True) @@ -257,6 +330,7 @@ def test_add_eol_cell(self): assert isinstance(p.eol_cell, ProcessCell) def test_add_eol_cell_twice_fails(self): + """Test eol twice fails.""" p = ProductionLineGraph() p.new_cell(is_eol=True) @@ -269,6 +343,7 @@ def test_add_eol_cell_twice_fails(self): ids=["some edges", "zero edges", "all edges"], ) def test_connect_cells_works_with_single_cell(self, n_nodes, forward_prob): + """Test connect.""" # Arrange p = ProductionLineGraph() p.new_cell(name="C1") @@ -305,6 +380,12 @@ def test_connect_cells_works_with_single_cell(self, n_nodes, forward_prob): ], ) def test_connect_cells_works_with_multiple_cells(self, n_nodes, forward_probs): + """Test connect with multiple cells. + + Args: + n_nodes (_type_): _description_ + forward_probs (_type_): _description_ + """ # Arrange p = ProductionLineGraph() @@ -335,6 +416,12 @@ def test_connect_cells_works_with_multiple_cells(self, n_nodes, forward_probs): ids=["some edges", "zero edges", "all edges"], ) def test_connect_cells_works_with_eol_cell(self, n_nodes, forward_prob): + """Test connect with eol. 
+ + Args: + n_nodes (_type_): _description_ + forward_prob (_type_): _description_ + """ # Arrange p = ProductionLineGraph() @@ -357,6 +444,7 @@ def test_connect_cells_works_with_eol_cell(self, n_nodes, forward_prob): assert len(p.graph.edges) == no_of_expected_edges def test_connect_across_cells_manually(self): + """TEst manual connection.""" n_nodes = 10 prob = 0.1 p = ProductionLineGraph() @@ -376,6 +464,7 @@ def test_connect_across_cells_manually(self): assert len(p.graph.edges) == edges_in_C1 + edges_in_C2 + len(edgelist) def test_acyclicity_error(self): + """Test acycicity.""" n_nodes = 5 prob = 0.1 p = ProductionLineGraph() @@ -396,6 +485,7 @@ def test_acyclicity_error(self): print(p.graph.edges) def test_ground_truth_visible(self): + """Test ground truth visible.""" n_nodes = 5 prob = 0.1 p = ProductionLineGraph() @@ -414,6 +504,8 @@ def test_ground_truth_visible(self): assert len(p._pairs_with_hidden_mediators()) == 0 def test_ground_truth_hidden(self): + """TEst hidden gt.""" + TWO = 2 edges1 = [(1, 2), (2, 3)] edges2 = [(1, 3), (2, 3)] edges3 = [(1, 2), (2, 3)] @@ -444,10 +536,13 @@ def test_ground_truth_hidden(self): # {(C2_M1_1, C2_M1_2): C1_M1_3} assert p._pairs_with_hidden_confounders() == {("C2_M1_1", "C2_M1_2"): ["C1_M1_3"]} - assert p.ground_truth_visible.loc[("C2_M1_1", "C2_M1_2")] == 2 - assert p.ground_truth_visible.loc[("C2_M1_2", "C2_M1_1")] == 2 + assert p.ground_truth_visible.loc[("C2_M1_1", "C2_M1_2")] == TWO + assert p.ground_truth_visible.loc[("C2_M1_2", "C2_M1_1")] == TWO def test_input_cellgraph_directly(self): + """Test input directly.""" + SIX = 6 + FOUR = 4 dag1 = nx.DiGraph([(0, 1), (1, 2)]) dag2 = nx.DiGraph([(3, 4), (3, 5)]) @@ -457,20 +552,24 @@ def test_input_cellgraph_directly(self): testline.new_cell(name="Station2") testline.Station2.input_cellgraph_directly(graph=dag2) - assert testline.num_nodes == 6 - assert testline.num_edges == 4 + assert testline.num_nodes == SIX + assert testline.num_edges == FOUR assert testline.sparsity == pytest.approx(4 / math.comb(6, 2)) def test_drf_size(self): + """TEst drf size.""" testline = ProductionLineGraph() assert not testline.drf def test_drf_error(self): + """TEst drf error.""" testline = ProductionLineGraph() with pytest.raises(ValueError): testline.sample_from_drf() def test_from_nx(self): + """Test from nx.""" + TWO = 2 nx_graph = nx.DiGraph( [("1", "2"), ("1", "3"), ("1", "4"), ("2", "5"), ("2", "6"), ("5", "6")] ) @@ -478,7 +577,7 @@ def test_from_nx(self): cell_mapper = {"cell1": ["1", "2", "3", "4"], "cell2": ["5", "6"]} pline_from_nx = ProductionLineGraph.from_nx(g=nx_graph, cell_mapper=cell_mapper) - assert len(pline_from_nx.cells) == 2 + assert len(pline_from_nx.cells) == TWO assert set(pline_from_nx.cell1.nodes) == { "cell1_1", "cell1_2", @@ -490,18 +589,30 @@ def test_from_nx(self): ProductionLineGraph.from_nx(pd_graph, cell_mapper=cell_mapper) def test_save_and_load_drf(self, tmp_path_factory): + """Test save and load. + + Args: + tmp_path_factory (_type_): _description_ + """ basedir = tmp_path_factory.mktemp("data") filename = "drf.pkl" line1 = ProductionLineGraph() - line1.drf = np.array([[1, 2, 3]]) + line1.drf = np.array([[1, 2, 3]]) # type: ignore line1.save_drf(filename=filename, location=basedir) line2 = ProductionLineGraph() line2.drf = ProductionLineGraph.load_drf(filename=filename, location=basedir) - assert np.array_equal(line2.drf, np.array([[1, 2, 3]])) + assert np.array_equal(line2.drf, np.array([[1, 2, 3]])) # type: ignore def test_pickleability(self, tmp_path): + """Test pickle. 
+ + Args: + tmp_path (_type_): _description_ + """ + SEVEN = 7 + TEN = 10 # Arrange filename_path = os.path.join(tmp_path, "pline.pkl") pline = ProductionLineGraph() @@ -523,10 +634,11 @@ def test_pickleability(self, tmp_path): # Assert assert new_edge in pline_reloaded.edges - assert len(pline.Station1.nodes) == 7 - assert len(pline.Station2.nodes) == 10 + assert len(pline.Station1.nodes) == SEVEN + assert len(pline.Station2.nodes) == TEN def test_copy(self): + """Test copy.""" pline = ProductionLineGraph() pline.new_cell(name="Station1") pline.new_cell(name="Station2") @@ -543,6 +655,7 @@ def test_copy(self): assert pline.cell_order == copyline.cell_order def test_within_edges_with_empty_cells_raises_error(self): + """Test error.""" # Setup pline = ProductionLineGraph() @@ -551,6 +664,7 @@ def test_within_edges_with_empty_cells_raises_error(self): pline.within_adjacency def test_between_edges_with_empty_cells_raises_error(self): + """Test error.""" # Setup pline = ProductionLineGraph() @@ -559,6 +673,7 @@ def test_between_edges_with_empty_cells_raises_error(self): print(pline.between_adjacency) def test_within_edges_adjacency_matrix(self): + """Test within amat.""" # Setup nx_graph = nx.DiGraph( [("1", "2"), ("1", "3"), ("1", "4"), ("2", "5"), ("2", "6"), ("5", "6")] @@ -583,9 +698,11 @@ def test_within_edges_adjacency_matrix(self): assert ( within_amat.loc["cell1_2", :].sum() == 0 and pline.ground_truth.loc["cell1_2", :].sum() != 0 - ) + ) # type: ignore def test_between_edges_adjacency_matrix(self): + """Test between edges amat.""" + TWO = 2 # Setup nx_graph = nx.DiGraph( [("1", "2"), ("1", "3"), ("1", "4"), ("2", "5"), ("2", "6"), ("5", "6")] @@ -599,13 +716,14 @@ def test_between_edges_adjacency_matrix(self): # Assert assert ( - between_amat.loc["cell1_2", :].sum() == 2 - and pline.ground_truth.loc["cell1_2", :].sum() == 2 - ) + between_amat.loc["cell1_2", :].sum() == TWO + and pline.ground_truth.loc["cell1_2", :].sum() == TWO + ) # type: ignore assert between_amat.loc[pline.cell1.nodes, pline.cell1.nodes].sum().sum() == 0 assert between_amat.loc[pline.cell2.nodes, pline.cell2.nodes].sum().sum() == 0 def test_interventional_drf_error(self): + """Test interventional drf.""" testline = ProductionLineGraph() with pytest.raises(ValueError): testline.sample_from_interventional_drf() diff --git a/tests/test_models_fcm.py b/tests/test_models_fcm.py index 14e7292..7824d01 100644 --- a/tests/test_models_fcm.py +++ b/tests/test_models_fcm.py @@ -1,4 +1,5 @@ -""" Utility classes and functions related to causalAssembly. +"""Utility classes and functions related to causalAssembly. + Copyright (c) 2023 Robert Bosch GmbH This program is free software: you can redistribute it and/or modify @@ -24,8 +25,19 @@ class TestFCM: + """Test FCM class. + + Returns: + _type_: _description_ + """ + @pytest.fixture(scope="class") def example_fcm(self): + """Set up example fcm. + + Returns: + _type_: _description_ + """ x, y, z = symbols("x,y,z") eq_x = Eq(x, Uniform("error", left=-1, right=1)) @@ -35,12 +47,17 @@ def example_fcm(self): eq_list = [eq_x, eq_y, eq_z] example_fcm = FCM(name="example_fcm", seed=2023) - example_fcm.input_fcm(eq_list) + example_fcm.input_fcm(eq_list) # type: ignore return example_fcm @pytest.fixture(scope="class") def medium_example_fcm(self) -> FCM: + """Set up medium size exmaple. 
diff --git a/tests/test_models_fcm.py b/tests/test_models_fcm.py
index 14e7292..7824d01 100644
--- a/tests/test_models_fcm.py
+++ b/tests/test_models_fcm.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
 Copyright (c) 2023 Robert Bosch GmbH
 
 This program is free software: you can redistribute it and/or modify
@@ -24,8 +25,19 @@
 
 
 class TestFCM:
+    """Test FCM class."""
+
     @pytest.fixture(scope="class")
     def example_fcm(self):
+        """Set up an example FCM.
+
+        Returns:
+            FCM: a small three-node example model.
+        """
         x, y, z = symbols("x,y,z")
 
         eq_x = Eq(x, Uniform("error", left=-1, right=1))
@@ -35,12 +47,17 @@
         eq_list = [eq_x, eq_y, eq_z]
 
         example_fcm = FCM(name="example_fcm", seed=2023)
-        example_fcm.input_fcm(eq_list)
+        example_fcm.input_fcm(eq_list)  # type: ignore
 
         return example_fcm
 
     @pytest.fixture(scope="class")
     def medium_example_fcm(self) -> FCM:
+        """Set up a medium-sized example FCM.
+
+        Returns:
+            FCM: a four-node example model.
+        """
         v, x, y, z = symbols("v,x,y,z")
 
         eq_x = Eq(x, Normal("error", 0, 1))
@@ -51,24 +68,29 @@
         eq_list = [eq_v, eq_x, eq_y, eq_z]
 
         test_fcm = FCM(name="testing", seed=2023)
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         return test_fcm
 
     def test_instance_is_created(self):
+        """Test that an instance is created."""
         h = FCM(name="mymodel", seed=1234)
         assert isinstance(h, FCM)
 
     def test_input_fcm_works(self, example_fcm):
+        """Test that inputting an FCM works."""
+        THREE = 3
         # Act
         x, y = symbols("x,y")
 
         # Assert
         assert len(example_fcm.source_nodes) == 1
-        assert example_fcm.num_nodes == 3
-        assert example_fcm.num_edges == 3
+        assert example_fcm.num_nodes == THREE
+        assert example_fcm.num_edges == THREE
         assert (x, y) in example_fcm.edges
 
     def test_empty_graph_works(self):
+        """Test that an edgeless graph works."""
+        THREE = 3
         # Arrange
         x, y, z = symbols("x,y,z")
@@ -78,24 +100,31 @@
 
         # Act
         test_fcm = FCM()
-        test_fcm.input_fcm([eq_x, eq_y, eq_z])
+        test_fcm.input_fcm([eq_x, eq_y, eq_z])  # type: ignore
         df = test_fcm.sample(size=5)
 
         # Assert
-        assert test_fcm.num_nodes == 3
+        assert test_fcm.num_nodes == THREE
         assert test_fcm.num_edges == 0
         assert df.shape == (5, 3)
 
     def test_draw_without_noise_works(self, example_fcm):
+        """Test sampling without noise.
+
+        Args:
+            example_fcm (FCM): example model fixture.
+        """
+        TEN = 10
         # Act
         df = example_fcm.sample(size=10, additive_gaussian_noise=False)
 
         # Assert
-        assert len(df) == 10
+        assert len(df) == TEN
 
     def test_draw_with_noise_works(self, example_fcm):
+        """Test sampling with noise."""
         # Act
         df_without_noise = example_fcm.sample(size=10, additive_gaussian_noise=False)
         df_with_noise = example_fcm.sample(size=10, additive_gaussian_noise=True)
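The fixtures and sampling tests above encode the core FCM workflow: declare structural equations with sympy, hand them to input_fcm, then sample. A condensed sketch, assuming the module path causalAssembly.models_fcm (the equation for y is illustrative, taken from the polynomial test further down):

    # Sketch: a two-node FCM sampled with and without additive noise.
    from sympy import Eq, symbols
    from sympy.stats import Uniform

    from causalAssembly.models_fcm import FCM  # assumed module path

    x, y = symbols("x,y")
    eq_x = Eq(x, Uniform("error", left=-1, right=1))  # source node
    eq_y = Eq(y, x**2 - 2 * x + 5)                    # deterministic child of x

    fcm = FCM(name="demo", seed=2023)
    fcm.input_fcm([eq_x, eq_y])

    df = fcm.sample(size=10, additive_gaussian_noise=False)  # columns "x", "y"
    df_noisy = fcm.sample(size=10, additive_gaussian_noise=True)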
@@ -106,6 +135,11 @@
         assert not np.allclose(df_without_noise[col], df_with_noise[col])
 
     def test_draw_from_dataframe(self, example_fcm: FCM):
+        """Test sampling from a source dataframe.
+
+        Args:
+            example_fcm (FCM): example model fixture.
+        """
         # Arrange
         source_df = pd.DataFrame()
         source_df["x"] = [0, 1.0, 10.0]
@@ -125,6 +159,7 @@
         )
 
     def test_draw_from_wrong_dataframe_raises_assertionerror(self, example_fcm):
+        """Test that sampling from a wrong dataframe raises an error."""
         source_df = pd.DataFrame()
         source_df["AAA"] = [0, 1.0, 10.0]
@@ -133,6 +168,8 @@
         example_fcm.sample(size=3, source_df=source_df, additive_gaussian_noise=False)
 
     def test_specify_individual_noise(self):
+        """Test specifying individual noise terms."""
+        THREE = 3
         # Arrange
         x, y, z = symbols("x,y,z")
@@ -143,7 +180,7 @@
         eq_list = [eq_x, eq_y, eq_z]
 
         test_fcm = FCM(name="testing", seed=2023)
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         # Act
@@ -151,10 +188,11 @@
 
         # Assert
-        assert test_fcm.num_edges == 3
+        assert test_fcm.num_edges == THREE
         assert df.shape == (10, 3)
 
     def test_order_in_eval_always_correct(self):
+        """Test that the evaluation order is always correct."""
         # Arrange
         v, w, x, y, z = symbols("v,w,x,y,z")
         eq_v = Eq(v, Uniform("noise", left=0.2, right=0.8))
@@ -167,7 +205,7 @@
 
         # Act
         test_fcm = FCM()
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
         df = test_fcm.sample(size=10)
         # Assert
         assert all(np.isclose(df["x"], 27 - df["v"]))
@@ -175,6 +213,8 @@
         assert all(np.isclose(df["z"], (df["x"] + df["y"]) / df["v"]))
 
     def test_select_scale_parameter_via_snr(self):
+        """Test selecting the scale parameter via SNR."""
+        THREE = 3
         # Arrange
         x, y, z = symbols("x,y,z")
         sigma = Symbol("sigma", positive=True)
@@ -186,7 +226,7 @@
         eq_list = [eq_x, eq_y, eq_z]
 
         test_fcm = FCM(name="testing", seed=2023)
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         # Act
@@ -194,10 +234,11 @@
 
         # Assert
-        assert test_fcm.num_edges == 3
+        assert test_fcm.num_edges == THREE
         assert df.shape == (10, 3)
 
     def test_select_scale_parameter_via_snr_gives_error_when_not_additive(self):
+        """Test that selecting the scale parameter via SNR errors when noise is not additive."""
         # Arrange
         x, y, z = symbols("x,y,z")
         sigma = Symbol("sigma", positive=True)
@@ -209,13 +250,19 @@
         eq_list = [eq_x, eq_y, eq_z]
 
         test_fcm = FCM(name="testing", seed=2023)
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         # Assert
         with pytest.raises(ValueError):
             test_fcm.sample(size=10, snr=0.6)
 
     def test_data_frame_with_fewer_columns_than_source_nodes(self, medium_example_fcm: FCM):
+        """Test a dataframe with fewer columns than source nodes.
+
+        Args:
+            medium_example_fcm (FCM): medium-sized example model fixture.
+        """
+        FOUR = 4
         # Act
         source_df = pd.DataFrame(
             {
@@ -226,10 +273,15 @@ def test_data_frame_with_fewer_columns_than_source_nodes(self, medium_example_fc
 
         df = medium_example_fcm.sample(size=10, snr=0.6, source_df=source_df)
 
         # Assert
-        assert medium_example_fcm.num_edges == 4
+        assert medium_example_fcm.num_edges == FOUR
         assert df.shape == (10, 4)
 
     def test_data_frame_too_few_rows(self, medium_example_fcm: FCM):
+        """Test that a dataframe with too few rows raises an error.
+
+        Args:
+            medium_example_fcm (FCM): medium-sized example model fixture.
+        """
         # Act
         source_df = pd.DataFrame(
             {
@@ -243,6 +295,7 @@
             medium_example_fcm.sample(size=10, snr=0.6, source_df=source_df)
 
     def test_polynomial_equation_works(self):
+        """Test that polynomial equations work."""
         # Arrange
         x, y = symbols("x,y")
@@ -250,12 +303,14 @@
         eq_y = Eq(y, x**2 - 2 * x + 5)
 
         test_fcm = FCM()
-        test_fcm.input_fcm([eq_x, eq_y])
+        test_fcm.input_fcm([eq_x, eq_y])  # type: ignore
         # Act
         df = test_fcm.sample(size=5)
         assert all(df["y"] == df["x"] ** 2 - 2 * df["x"] + 5)
 
     def test_display_functions_works(self):
+        """Test that display_functions works."""
+        FOUR = 4
         # Arrange
         v, x, y, z = symbols("v,x,y,z")
@@ -267,17 +322,18 @@
         eq_list = [eq_v, eq_x, eq_y, eq_z]
 
         test_fcm = FCM(name="testing", seed=2023)
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
         # Act
         functions_dict = test_fcm.display_functions()
         # Assert
         assert isinstance(functions_dict, dict)
-        assert len(functions_dict) == 4
+        assert len(functions_dict) == FOUR
         assert str(functions_dict[x]) == str(functions_dict[v]) == "error"
         assert functions_dict[y] == eq_y.rhs
         assert functions_dict[z] == eq_z.rhs
 
     def test_show_individual_function(self):
+        """Test showing an individual function."""
         # Arrange
         v, x, y, z = symbols("v,x,y,z")
@@ -289,16 +345,17 @@
         eq_list = [eq_v, eq_x, eq_y, eq_z]
 
         test_fcm = FCM(name="testing", seed=2023)
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         # Assert
         assert isinstance(test_fcm.function_of(node=x), dict)
         with pytest.raises(AssertionError):
-            test_fcm.function_of(node="x")
+            test_fcm.function_of(node="x")  # type: ignore
         with pytest.raises(AssertionError):
-            test_fcm.function_of(node="m")
+            test_fcm.function_of(node="m")  # type: ignore
         assert test_fcm.function_of(node=y) == {y: eq_y.rhs}
 
     def test_single_hard_intervention(self):
+        """Test a single hard intervention."""
         # Arrange
         x, y, z = symbols("x,y,z")
@@ -309,7 +366,7 @@
         eq_list = [eq_x, eq_y, eq_z]
 
         test_fcm = FCM()
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         # Act
         test_fcm.intervene_on({y: 2.5})
@@ -324,7 +381,9 @@
         assert np.isclose(test_df["y"].mean(), 2.5)
 
     def test_single_soft_intervention(self):
+        """Test a single soft intervention."""
         # Arrange
+        point_three = 0.3
         x, y, z = symbols("x,y,z")
 
         eq_x = Eq(x, Normal("error", 0, 1))
@@ -334,7 +393,7 @@
         eq_list = [eq_x, eq_y, eq_z]
 
         test_fcm = FCM()
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         # Act
         test_fcm.intervene_on({z: Uniform("noise", left=-0.3, right=0.3)})
@@ -346,9 +405,10 @@
         assert test_fcm.interventions[0] == "do([z])"
         assert len(test_fcm.mutilated_dags) == 1
         assert len(test_fcm.mutilated_dags["do([z])"].edges()) == test_fcm.num_edges - 2
-        assert -0.3 < test_df["z"].min() and test_df["z"].min() < 0.3
+        assert -point_three < test_df["z"].min() and test_df["z"].min() < point_three
 
     def test_multiple_interventions_at_once(self):
+        """Test multiple interventions at once."""
         # Arrange
         v, x, y, z = symbols("v,x,y,z")
@@ -360,7 +420,7 @@
         eq_list = [eq_v, eq_x, eq_y, eq_z]
 
         test_fcm = FCM()
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         # Act
         test_fcm.intervene_on({v: 5, z: Normal("error", 0, 1)})
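The intervention tests around this point fix the do-operator API: intervene_on mutilates the graph and registers an intervention, interventional_sample draws from the mutilated model. A sketch under stated assumptions; the module path and the 0-based which_intervention indexing are inferred from the tests (which pass which_intervention=1 for the second intervention):

    # Sketch: hard and soft interventions on an FCM.
    from sympy import Eq, symbols
    from sympy.stats import Normal, Uniform

    from causalAssembly.models_fcm import FCM  # assumed module path

    x, y = symbols("x,y")
    fcm = FCM()
    fcm.input_fcm([Eq(x, Normal("error", 0, 1)), Eq(y, x**2 - 2 * x + 5)])

    fcm.intervene_on({x: 5})  # hard intervention: do(x = 5)
    fcm.intervene_on({y: Uniform("noise", left=-0.3, right=0.3)})  # soft: replace y's mechanism

    df0 = fcm.interventional_sample(size=5, which_intervention=0)  # first do(), assumed index
    df1 = fcm.interventional_sample(size=5, which_intervention=1)  # second do()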
@@ -374,6 +434,8 @@
         assert np.isclose(test_df["v"].mean(), 5)
 
     def test_multiple_interventions_sequentually(self):
+        """Test multiple interventions applied sequentially."""
+        TWO = 2
         # Arrange
         v, x, y, z = symbols("v,x,y,z")
@@ -385,7 +447,7 @@
         eq_list = [eq_v, eq_x, eq_y, eq_z]
 
         test_fcm = FCM()
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         # Act
         test_fcm.intervene_on({v: 5})
@@ -395,13 +457,14 @@
         test_fcm.interventional_sample(size=5, which_intervention=1)
 
         # Assert
-        assert len(test_fcm.interventions) == 2
-        assert len(test_fcm.mutilated_dags) == 2
+        assert len(test_fcm.interventions) == TWO
+        assert len(test_fcm.mutilated_dags) == TWO
         assert len(test_fcm.mutilated_dags["do([v])"].edges()) == test_fcm.num_edges - 1
         assert len(test_fcm.mutilated_dags["do([z])"].edges()) == test_fcm.num_edges - 2
         assert np.isclose(test_df_1["v"].mean(), 5)
 
     def test_reproducability(self):
+        """Test reproducibility."""
         # Arrange
         v, x, y, z, sigma = symbols("v,x,y,z,sigma")
@@ -414,7 +477,7 @@
 
         # Act
         test_fcm = FCM(name="testing", seed=2023)
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         source_df = pd.DataFrame(
             {
@@ -425,7 +488,7 @@
         df_a = test_fcm.sample(size=5, snr=2 / 3, additive_gaussian_noise=True, source_df=source_df)
 
         test_fcm = FCM(name="testing", seed=2023)
-        test_fcm.input_fcm(eq_list)
+        test_fcm.input_fcm(eq_list)  # type: ignore
 
         df_b = test_fcm.sample(size=5, snr=2 / 3, additive_gaussian_noise=True, source_df=source_df)
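The reproducibility test above encodes the seeding contract: two FCMs built with the same seed and the same equations yield identical samples. In short (module path assumed, as before):

    # Sketch: same seed + same equations => identical draws.
    from sympy import Eq, symbols
    from sympy.stats import Normal

    from causalAssembly.models_fcm import FCM  # assumed module path

    x, y = symbols("x,y")
    eq_list = [Eq(x, Normal("error", 0, 1)), Eq(y, 2 * x)]

    fcm_a = FCM(name="demo", seed=2023)
    fcm_a.input_fcm(eq_list)
    df_a = fcm_a.sample(size=5)

    fcm_b = FCM(name="demo", seed=2023)
    fcm_b.input_fcm(eq_list)
    df_b = fcm_b.sample(size=5)  # expected to equal df_a row for row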
diff --git a/tests/test_pdag.py b/tests/test_pdag.py
index f73d7c1..b9d288c 100644
--- a/tests/test_pdag.py
+++ b/tests/test_pdag.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
 Copyright (c) 2023 Robert Bosch GmbH
 
 This program is free software: you can redistribute it and/or modify
@@ -20,10 +21,17 @@
 
 from causalAssembly.pdag import PDAG, dag2cpdag
 
+TWO = 2
+THREE = 3
+FOUR = 4
+
 
 class TestPDAG:
+    """Test PDAG class."""
+
     @pytest.fixture(scope="class")
     def mixed_pdag(self) -> PDAG:
+        """Set up a mixed PDAG fixture."""
         pdag = PDAG(
             nodes=["A", "B", "C"],
             dir_edges=[("A", "B"), ("A", "C")],
@@ -32,43 +40,55 @@
         return pdag
 
     def test_instance_is_created(self):
+        """Test that an instance is created."""
         pdag = PDAG(nodes=["A", "B", "C"])
         assert isinstance(pdag, PDAG)
 
     def test_dir_edges(self):
+        """Test directed edges."""
         pdag = PDAG(nodes=["A", "B", "C"], dir_edges=[("A", "B"), ("A", "C")])
-        assert pdag.num_dir_edges == 2
+        assert pdag.num_dir_edges == TWO
         assert pdag.num_undir_edges == 0
         assert set(pdag.dir_edges) == {("A", "B"), ("A", "C")}
 
     def test_undir_edges(self):
+        """Test undirected edges."""
         pdag = PDAG(nodes=["A", "B", "C"], undir_edges=[("A", "B"), ("A", "C")])
         assert pdag.num_dir_edges == 0
-        assert pdag.num_undir_edges == 2
+        assert pdag.num_undir_edges == TWO
         assert set(pdag.undir_edges) == {("A", "B"), ("A", "C")}
 
     def test_mixed_edges(self, mixed_pdag: PDAG):
-        assert mixed_pdag.num_dir_edges == 2
+        """Test mixed edges."""
+        assert mixed_pdag.num_dir_edges == TWO
         assert mixed_pdag.num_undir_edges == 1
         assert set(mixed_pdag.dir_edges) == {("A", "B"), ("A", "C")}
         assert set(mixed_pdag.undir_edges) == {("B", "C")}
 
     def test_children(self, mixed_pdag: PDAG):
+        """Test children."""
         assert mixed_pdag.children(node="A") == {"B", "C"}
         assert mixed_pdag.children(node="B") == set()
         assert mixed_pdag.children(node="C") == set()
 
     def test_parents(self, mixed_pdag: PDAG):
+        """Test parents."""
         assert mixed_pdag.parents(node="A") == set()
         assert mixed_pdag.parents(node="B") == {"A"}
         assert mixed_pdag.parents(node="C") == {"A"}
 
     def test_neighbors(self, mixed_pdag: PDAG):
+        """Test neighbors."""
         assert mixed_pdag.neighbors(node="C") == {"B", "A"}
         assert mixed_pdag.undir_neighbors(node="C") == {"B"}
         assert mixed_pdag.is_adjacent(i="B", j="C")
 
     def test_from_pandas_adjacency(self, mixed_pdag: PDAG):
+        """Test construction from a pandas adjacency matrix.
+
+        Args:
+            mixed_pdag (PDAG): mixed PDAG fixture.
+        """
         amat = pd.DataFrame(
             [[0, 1, 1], [0, 0, 1], [0, 1, 0]],
             columns=["A", "B", "C"],
@@ -80,6 +100,11 @@
         assert from_pandas_pdag.num_undir_edges == mixed_pdag.num_undir_edges
 
     def test_remove_edge(self, mixed_pdag: PDAG):
+        """Test removing an edge.
+
+        Args:
+            mixed_pdag (PDAG): mixed PDAG fixture.
+        """
         assert ("A", "C") in mixed_pdag.dir_edges
         mixed_pdag.remove_edge("A", "C")
         assert ("A", "C") not in mixed_pdag.dir_edges
@@ -87,6 +112,11 @@
             mixed_pdag.remove_edge("B", "A")
 
     def test_change_undir_edge_to_dir_edge(self, mixed_pdag: PDAG):
+        """Test changing an undirected edge into a directed edge.
+
+        Args:
+            mixed_pdag (PDAG): mixed PDAG fixture.
+        """
         assert ("B", "C") in mixed_pdag.undir_edges or (
             "C",
             "B",
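The fixture above is the canonical mixed PDAG for this file, and the accessor tests define its query surface. A compact sketch using only constructs that appear in this diff (the import line is taken verbatim from the test file):

    # Sketch: building and querying a partially directed acyclic graph.
    from causalAssembly.pdag import PDAG

    pdag = PDAG(
        nodes=["A", "B", "C"],
        dir_edges=[("A", "B"), ("A", "C")],
        undir_edges=[("B", "C")],
    )

    assert pdag.parents(node="C") == {"A"}
    assert pdag.undir_neighbors(node="C") == {"B"}

    dag = pdag.to_dag()  # one consistent DAG extension of the PDAG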
@@ -98,25 +128,41 @@
         assert ("C", "B") not in mixed_pdag.undir_edges
 
     def test_remove_node(self, mixed_pdag: PDAG):
+        """Test removing a node.
+
+        Args:
+            mixed_pdag (PDAG): mixed PDAG fixture.
+        """
         assert "C" in mixed_pdag.nodes
         mixed_pdag.remove_node("C")
         assert "C" not in mixed_pdag.nodes
 
     def test_to_dag(self, mixed_pdag: PDAG):
+        """Test conversion to a DAG.
+
+        Args:
+            mixed_pdag (PDAG): mixed PDAG fixture.
+        """
         dag = mixed_pdag.to_dag()
         assert nx.is_directed_acyclic_graph(dag)
         assert set(mixed_pdag.dir_edges).issubset(set(dag.edges))
 
     def test_adjacency_matrix(self, mixed_pdag: PDAG):
+        """Test the returned adjacency matrix.
+
+        Args:
+            mixed_pdag (PDAG): mixed PDAG fixture.
+        """
         amat = mixed_pdag.adjacency_matrix
         assert amat.shape[0] == amat.shape[1] == mixed_pdag.nnodes
         assert amat.sum().sum() == mixed_pdag.num_dir_edges + 2 * mixed_pdag.num_undir_edges
 
     def test_dag2cpdag(self):
+        """Test conversion from DAG to CPDAG."""
         dag1 = nx.DiGraph([("1", "2"), ("2", "3"), ("3", "4")])
         cpdag1 = dag2cpdag(dag=dag1)
         assert cpdag1.num_dir_edges == 0
-        assert cpdag1.num_undir_edges == 3
+        assert cpdag1.num_undir_edges == THREE
 
         dag2 = nx.DiGraph([("1", "3"), ("2", "3")])
         cpdag2 = dag2cpdag(dag=dag2)
@@ -125,10 +171,11 @@
 
         dag3 = nx.DiGraph([("1", "3"), ("2", "3"), ("1", "4")])
         cpdag3 = dag2cpdag(dag=dag3)
-        assert cpdag3.num_dir_edges == 2
+        assert cpdag3.num_dir_edges == TWO
         assert cpdag3.num_undir_edges == 1
 
     def test_example_a_to_allDAGs(self):
+        """Test enumerating all DAGs of example PDAG a."""
         # Set up CPDAG: a - c - b -> MEC has 3 Members
         pdag = nx.Graph([("a", "c"), ("b", "c")])
         amat = nx.to_pandas_adjacency(pdag)
@@ -139,10 +186,11 @@
 
         # Act
         all_dags = example_pdag.to_allDAGs()
-        assert len(all_dags) == 3
+        assert len(all_dags) == THREE
         assert all([isinstance(dag, nx.DiGraph) for dag in all_dags])
 
     def test_example_b_to_allDAGs(self):
+        """Test enumerating all DAGs of example PDAG b."""
         # Set up CPDAG: a - (b,c,d) -> MEC has 4 Members
         pdag = nx.Graph([("a", "b"), ("a", "c"), ("a", "d")])
         amat = nx.to_pandas_adjacency(pdag)
@@ -153,10 +201,11 @@
 
         # Act
         all_dags = example_pdag.to_allDAGs()
-        assert len(all_dags) == 4
+        assert len(all_dags) == FOUR
         assert all([isinstance(dag, nx.DiGraph) for dag in all_dags])
 
     def test_empty_graph_to_allDAGs(self):
+        """Test converting an empty graph to all DAGs."""
         # Set up empty PDAG, has exaclty one DAG that is the same as the PDAG.
         pdag = nx.Graph()
         pdag.add_nodes_from(["a", "b", "c", "d"])
@@ -173,6 +222,7 @@
         assert set(all_dags[0].nodes) == set(pdag.nodes)
 
     def test_to_random_dag(self):
+        """Test drawing a random DAG from the PDAG."""
         # Set up CPDAG: a - (b,c,d) -> MEC has 4 Members
         pdag = nx.Graph([("a", "b"), ("a", "c"), ("a", "d")])
         amat = nx.to_pandas_adjacency(pdag)
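Finally, the equivalence-class tests tie the pieces together: dag2cpdag maps a DAG to its CPDAG, and to_allDAGs enumerates the Markov equivalence class (MEC). A sketch; PDAG.from_pandas_adjacency is an assumed constructor name, inferred from test_from_pandas_adjacency above:

    # Sketch: from a DAG to its CPDAG, and from a CPDAG to all MEC members.
    import networkx as nx

    from causalAssembly.pdag import PDAG, dag2cpdag

    chain = nx.DiGraph([("1", "2"), ("2", "3"), ("3", "4")])
    cpdag = dag2cpdag(dag=chain)  # a chain has no v-structures, so all edges come back undirected

    # Build the CPDAG a - c - b from an adjacency matrix; its MEC has 3 members.
    amat = nx.to_pandas_adjacency(nx.Graph([("a", "c"), ("b", "c")]))
    example_pdag = PDAG.from_pandas_adjacency(amat)  # assumed constructor
    all_dags = example_pdag.to_allDAGs()
    assert len(all_dags) == 3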