diff --git a/.github/workflows/test_package.yml b/.github/workflows/test_package.yml
index d04345d..b2ab4bb 100644
--- a/.github/workflows/test_package.yml
+++ b/.github/workflows/test_package.yml
@@ -22,10 +22,14 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
+ - name: Set up R
+ uses: r-lib/actions/setup-r@v2
+ with:
+ r-version: '4.3.2'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r requirements_dev.txt
- name: Test with pytest
run: |
- python -m pytest
+ python -m pytest --cov=causalAssembly tests/
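The coverage flag and the new R toolchain step go together: the DRF code in this repo wraps an R package via rpy2, so CI needs R available to exercise that path. A minimal smoke test, assuming rpy2 is pulled in via requirements_dev.txt, to confirm the runner's R installation is visible from Python:

    # hedged check: does rpy2 see the R set up by r-lib/actions/setup-r?
    import rpy2.robjects as ro

    print(ro.r("R.version.string")[0])  # expected to report 4.3.2 on this runner
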
diff --git a/.lintr b/.lintr
new file mode 100644
index 0000000..8c59a64
--- /dev/null
+++ b/.lintr
@@ -0,0 +1,5 @@
+linters: linters_with_defaults(
+ line_length_linter = line_length_linter(120L),
+ object_name_linter = NULL,
+ indentation_linter = NULL
+ )
diff --git a/Makefile b/Makefile
index 612ffb9..d531305 100644
--- a/Makefile
+++ b/Makefile
@@ -1,28 +1,50 @@
#Makefile for causalAssembly
+
MAKEFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST)))
CURRENT_DIR := $(notdir $(patsubst %/,%,$(dir $(MAKEFILE_PATH))))
VENV_NAME := venv_${CURRENT_DIR}
PYTHON=${VENV_NAME}/bin/python
+ALLOWED_LICENSES := "$(shell tr -s '\n' ';' < allowed_licenses.txt)"
+ALLOWED_PACKAGES := $(shell tr -s '\n' ' ' < allowed_packages.txt)
+
sync-venv:
: # Create or update default virtual environment to latest pinned
: # dependencies
test -d $(VENV_NAME) || \
python3.10 -m virtualenv $(VENV_NAME); \
- ${PYTHON} -m pip install -U pip; \
- ${PYTHON} -m pip install pip-tools
- . $(VENV_NAME)/bin/activate && pip-sync requirements_dev.txt
- #. $(VENV_NAME)/bin/activate && pip install --no-deps -e .
+ ${PYTHON} -m pip install -U uv; \
+ . $(VENV_NAME)/bin/activate && uv pip sync requirements_dev.txt
+ . $(VENV_NAME)/bin/activate && uv pip install --no-deps -e .
requirements:
: # Update requirements_dev.txt if only new library is added
-	: # Assumes virtual environment with pip-tools installed is activated
+	: # Assumes virtual environment with uv installed is activated
- pip-compile --extra dev -o requirements_dev.txt pyproject.toml --annotation-style line --no-emit-index-url --no-emit-trusted-host --allow-unsafe --resolver=backtracking
+ uv pip compile pyproject.toml --output-file=requirements.txt --annotation-style=line
+ uv pip compile pyproject.toml --extra=dev --output-file=requirements_dev.txt --annotation-style=line
update-requirements:
: # Update requirements_dev.txt if dependencies should be updated
-	: # Assumes virtual environment with pip-tools installed is activated
+	: # Assumes virtual environment with uv installed is activated
- pip-compile --extra dev -o requirements_dev.txt pyproject.toml --annotation-style line --no-emit-index-url --no-emit-trusted-host --allow-unsafe --resolver=backtracking --upgrade
+ uv pip compile pyproject.toml --output-file=requirements.txt --annotation-style=line --upgrade
+ uv pip compile pyproject.toml --extra=dev --output-file=requirements_dev.txt --annotation-style=line --upgrade
+
+
+precommit:
+ : # Run precommit on all files locally (this runs only the pre-commit and not the pre-push hooks)
+ pre-commit install
+ pre-commit run --all-files
+
+test:
+ : # Run pytest
+ python -m pytest
+
+licenses:
+ : # Create file with used licenses and check if these are permissive
+ pip-licenses --order=license --ignore-packages ${ALLOWED_PACKAGES} --allow-only=${ALLOWED_LICENSES} > licenses.txt
+
+clean-branches:
+	git branch --merged | grep -v '\*\|main\|develop' | xargs -n 1 -r git branch -d
diff --git a/README.md b/README.md
index 3585308..30e976f 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
This repo provides details regarding $\texttt{causalAssembly}$, a causal discovery benchmark data tool based on complex production data.
Theoretical details and information regarding construction are presented in the [paper](https://arxiv.org/abs/2306.10816):
- Göbler, K., Windisch, T., Pychynski, T., Sonntag, S., Roth, M., & Drton, M. causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery, to appear in Proceedings of the 3rd Conference on Causal Learning and Reasoning (CLeaR), 2024,
+ Göbler, K., Windisch, T., Pychynski, T., Sonntag, S., Roth, M., & Drton, M. causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery, to appear in Proceedings of the 3rd Conference on Causal Learning and Reasoning (CLeaR), 2024,
## Authors
* [Konstantin Goebler](mailto:konstantin.goebler@de.bosch.com)
* [Steffen Sonntag](mailto:steffen.sonntag@de.bosch.com)
diff --git a/VERSION b/VERSION
index 781dcb0..26aaba0 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-1.1.3
+1.2.0
diff --git a/allowed_licenses.txt b/allowed_licenses.txt
new file mode 100644
index 0000000..dc8290e
--- /dev/null
+++ b/allowed_licenses.txt
@@ -0,0 +1,24 @@
+Apache-2.0
+Apache 2.0
+Apache License 2.0
+APACHE SOFTWARE LICENSE
+BOSCH-INTERNAL
+BSD 3-Clause
+BSD License
+BSD
+3-Clause BSD License
+BSD (3-Clause)
+ISC
+ISC License (ISCL)
+Other/Proprietary License
+MIT
+MIT License
+CMU License (MIT-CMU)
+GPL-2.0-or-later
+GNU General Public License v2 or later (GPLv2+)
+Python Software Foundation License
+The Unlicense (Unlicense)
+Historical Permission Notice and Disclaimer (HPND)
+MPL-2.0
+NVIDIA Proprietary Software
+UNKNOWN
diff --git a/allowed_packages.txt b/allowed_packages.txt
new file mode 100644
index 0000000..f7aa7de
--- /dev/null
+++ b/allowed_packages.txt
@@ -0,0 +1,8 @@
+causalAssembly
+rpy2
+chardet
+certifi
+dbx
+mdpcommonlib
+pathspec
+text-unidecode
diff --git a/benchmarks/run.py b/benchmarks/run.py
index eccbc36..cce3de5 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
diff --git a/benchmarks/utils.py b/benchmarks/utils.py
index adbfc43..f68b1aa 100644
--- a/benchmarks/utils.py
+++ b/benchmarks/utils.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -12,11 +13,13 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see .
"""
+
import logging
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from itertools import repeat
+from typing import Any
import lingam
import networkx as nx
@@ -40,12 +43,15 @@
class BenchMarker:
- """Class to run causal discovery benchmarks. One instance can be
+ """Class to run causal discovery benchmarks.
+
+ One instance can be
the basis for several benchmark runs based on the same settings
with different data generator objects.
"""
def __init__(self):
+ """Initializes the BenchMarker class."""
self.algorithms: dict = {"snr": BenchMarker._fit_snr}
self.collect_results: dict = {}
self.num_runs: int
@@ -54,37 +60,37 @@ def __init__(self):
self.n_select: int
def include_pc(self):
- """Includes the PC-stable algorithm from the `causal-learn` package"""
+ """Includes the PC-stable algorithm from the `causal-learn` package."""
logger.info("PC algorithm added to benchmark routines.")
self.algorithms["pc"] = BenchMarker._fit_pc
def include_ges(self):
- """Includes GES from the `causal-learn` package"""
+ """Includes GES from the `causal-learn` package."""
logger.info("GES algorithm added to benchmark routines.")
self.algorithms["ges"] = BenchMarker._fit_ges
def include_notears(self):
- """Includes the NOTEARS algorithm from the `gcastle` package"""
+ """Includes the NOTEARS algorithm from the `gcastle` package."""
logger.info("NOTEARS added to benchmark routines.")
self.algorithms["notears"] = BenchMarker._fit_notears
def include_grandag(self):
- """Includes the Gran-DAG algorithm from the `gcastle` package"""
+ """Includes the Gran-DAG algorithm from the `gcastle` package."""
logger.info("Gran-DAG added to benchmark routines.")
self.algorithms["grandag"] = BenchMarker._fit_grandag
def include_score(self):
- """Includes the SCORE algorithm from the `dodiscovery` package"""
+ """Includes the SCORE algorithm from the `dodiscovery` package."""
logger.info("SCORE algorithm added to benchmark routines.")
self.algorithms["score"] = BenchMarker._fit_score
def include_das(self):
- """Includes the DAS algorithm from the `dodiscovery` package"""
+ """Includes the DAS algorithm from the `dodiscovery` package."""
logger.info("DAS algorithm added to benchmark routines.")
self.algorithms["das"] = BenchMarker._fit_das
def include_lingam(self):
- """Includes the DirectLiNGAM algorithm from the `lingam` package"""
+ """Includes the DirectLiNGAM algorithm from the `lingam` package."""
logger.info("Direct LiNGAM added to benchmark routines.")
self.algorithms["lingam"] = BenchMarker._fit_lingam
@@ -112,7 +118,7 @@ def _causallearn2amat(causal_learn_graph: np.ndarray) -> np.ndarray:
if causal_learn_graph[row, col] == -1 and causal_learn_graph[col, row] == -1:
amat[row, col] = amat[col, row] = 1
if causal_learn_graph[row, col] == 1 and causal_learn_graph[col, row] == 1:
- logger.warning(f"ambiguity found in {(row,col)}. I'll make it bidirected")
+ logger.warning(f"ambiguity found in {(row, col)}. I'll make it bidirected")
amat[row, col] = amat[col, row] = 1
return amat
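The conversion above hinges on causal-learn's matrix encoding, where a -1/-1 pair marks an undirected edge and a 1/1 pair is ambiguous. A toy illustration of the undirected case, under that assumption:

    import numpy as np

    cl_graph = np.array([[0, -1], [-1, 0]])  # one undirected edge 0 -- 1
    amat = np.zeros_like(cl_graph)
    rows, cols = np.where((cl_graph == -1) & (cl_graph.T == -1))
    amat[rows, cols] = 1  # symmetric ones encode the undirected edge
    print(amat)           # [[0 1], [1 0]]
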
@@ -131,9 +137,12 @@ def _fit_ges(data: pd.DataFrame) -> PDAG:
# Beware of the Simulated DAG! Causal Discovery Benchmarks May Be Easy To Game.
@staticmethod
def _fit_snr(data: pd.DataFrame) -> pd.DataFrame:
- """Take n x d data, order nodes by marginal variance and
+ """SNR algo.
+
+ Take n x d data, order nodes by marginal variance and
regresses each node onto those with lower variance, using
- edge coefficients as structure estimates."""
+ edge coefficients as structure estimates.
+ """
X = data.to_numpy()
LR = LinearRegression()
LL = LassoLarsIC(criterion="bic")
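_fit_snr is the sort-and-regress baseline from "Beware of the Simulated DAG!" (Reisach et al.). A stripped-down sketch of the idea, using plain least squares in place of the LassoLarsIC selection step used here:

    import numpy as np
    from sklearn.linear_model import LinearRegression

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 3)) * np.array([1.0, 2.0, 3.0])  # growing variance
    order = np.argsort(X.var(axis=0))  # sort nodes by marginal variance
    W = np.zeros((3, 3))
    for k, node in enumerate(order[1:], start=1):
        preds = order[:k]  # regress on all lower-variance nodes
        W[preds, node] = LinearRegression().fit(X[:, preds], X[:, node]).coef_
    print(W.round(2))  # nonzero coefficients are the structure estimate
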
@@ -158,10 +167,13 @@ def _fit_snr(data: pd.DataFrame) -> pd.DataFrame:
# Beware of the Simulated DAG! Causal Discovery Benchmarks May Be Easy To Game.
@staticmethod
def varsortability(data: pd.DataFrame, ground_truth: pd.DataFrame, tol=1e-9):
- """Takes n x d data and a d x d adjaceny matrix,
+ """Varsortability algo.
+
+ Takes n x d data and a d x d adjaceny matrix,
where the i,j-th entry corresponds to the edge weight for i->j,
and returns a value indicating how well the variance order
- reflects the causal order."""
+ reflects the causal order.
+ """
X = data.to_numpy()
W = ground_truth.to_numpy()
E = W != 0
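For context, varsortability from the same paper measures how often variance increases along causal connections. A simplified edge-level reading (the full measure also accounts for longer paths via powers of E):

    import numpy as np

    def edge_varsortability(X: np.ndarray, W: np.ndarray, tol: float = 1e-9) -> float:
        # fraction of edges i -> j whose head variance exceeds the tail variance
        E = W != 0
        var = X.var(axis=0)
        tails, heads = np.where(E)
        return float(((var[heads] - var[tails]) > tol).mean())
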
@@ -226,7 +238,7 @@ def _fit_das(data: pd.DataFrame) -> pd.DataFrame:
def run_benchmark(
self,
runs: int,
- prod_obj: ProductionLineGraph | ProcessCell,
+ prod_obj: ProductionLineGraph,
n_select: int = 500,
harmonize_via: str | None = "cpdag_transform",
size_threshold: int = 50,
@@ -252,11 +264,24 @@ def run_benchmark(
`"best_dag_shd"` is selected, all DAGs in the implied MEC will be enumerated, the
SHD calculated and the lowest (best) candidate DAG chosen. Defaults to
"cpdag_transform".
+            size_threshold (int, optional): Graph size threshold. Defaults to 50.
parallelize (bool, optional): Whether to run on parallel processes. Defaults to False.
n_workers (int, optional): If `parallelize = True` you need to assign the number
of workers to prarallelize over. Defaults to 4.
seed_sequence (int, optional): If `parallelize = True` you may choose the seed sequence
handed down to every parallel process. Defaults to 1234.
+ chunksize (int | None): If `parallelize = True` you may choose the
+ chunksize for the parallelization. If `None`, it will be set to
+ `runs / n_workers` or 1, whichever is larger. Defaults to None.
+ external_dfs (list[pd.DataFrame | np.ndarray] | None, optional):
+ If you want to use external dataframes for the benchmark runs, you can pass a list
+ of dataframes or numpy arrays here. The length of the list must match the number of
+ runs. If `None`, the `prod_obj` will be used to sample data.
+ Defaults to None.
+ between_and_within_results (bool, optional): If `True`, the benchmark will also
+ return the within and between metrics for the `prod_obj` if it is a
+ `ProductionLineGraph`. If `False`, only the metrics for the overall graph will be
+ returned. Defaults to False.
"""
self.num_runs = runs
self.prod_object = prod_obj
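A minimal end-to-end sketch of the updated signature; the DRF wiring and file name are hypothetical and assume a previously saved forest:

    line = ProductionLineGraph.get_ground_truth()
    line.drf = ProductionLineGraph.load_drf("drf.pkl")  # hypothetical file name

    bm = BenchMarker()
    bm.include_pc()
    bm.run_benchmark(runs=5, prod_obj=line, n_select=500)
    print(bm.collect_results)  # metrics gathered across runs
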
@@ -359,9 +384,9 @@ def run_benchmark(
@staticmethod
def single_run(
- new_seed: np.random.BitGenerator | None,
- prod_obj: ProductionLineGraph | ProcessCell,
- algorithms: list[str],
+ new_seed: int | None,
+ prod_obj: ProductionLineGraph,
+ algorithms: dict[str, Any],
child_seed: None | np.random.SeedSequence = None,
harmonize_via: None | str = "cpdag_transform",
n_select: int = 500,
@@ -369,29 +394,37 @@ def single_run(
external_df: pd.DataFrame | np.ndarray | None = None,
between_and_within_results: bool = False,
):
- """Single benchmark run
+ """Single benchmark run.
Args:
- new_seed (np.random.BitGenerator | None): seed
- prod_obj (ProductionLineGraph | ProcessCell): object of interest
- algorithms (list[str]): cd algs
- child_seed (None | np.random.SeedSequence, optional): Defaults to None.
- harmonize_via (None | str, optional): Defaults to "cpdag_transform".
- n_select (int, optional): Defaults to 500.
- size_threshold (int, optional): Defaults to 50.
+            new_seed (int | None): Run index used to pick one of the spawned child seeds.
+            prod_obj (ProductionLineGraph): Production line object to benchmark against.
+            algorithms (dict[str, Any]): Mapping of algorithm names to fit functions.
+            child_seed (None | np.random.SeedSequence, optional): Seed sequence from which
+                per-run child seeds are spawned. Defaults to None.
+            harmonize_via (None | str, optional): Harmonization strategy, see
+                `run_benchmark`. Defaults to "cpdag_transform".
+            n_select (int, optional): Number of samples to draw. Defaults to 500.
+            size_threshold (int, optional): Graph size threshold. Defaults to 50.
+            external_df (pd.DataFrame | np.ndarray | None, optional): External data to use
+                instead of sampling from `prod_obj`. Defaults to None.
+            between_and_within_results (bool, optional): Whether to also collect within- and
+                between-cell metrics. Defaults to False.
Raises:
- AssertionError
- AssertionError
- TypeError
+            AssertionError: If input validation fails.
+            TypeError: If ground truth and result are not of the same type.
+            AssertionError: If the best-DAG selection fails.
Returns:
- dict: results from that run.
+            dict: Results from that run.
"""
metrics = defaultdict(partial(defaultdict, list))
within_between_metrics = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
- if child_seed:
- prod_obj.random_state = np.random.default_rng(seed=child_seed[new_seed])
+ if child_seed is not None:
+            num_seeds = 10  # fixed pool of child seeds; new_seed indexes into it
+            child_seeds = child_seed.spawn(num_seeds)
+            if new_seed is None:
+                new_seed = 0  # fall back to the first spawned child seed
+ prod_obj.random_state = np.random.default_rng(child_seeds[new_seed])
if external_df is not None:
if isinstance(external_df, pd.DataFrame):
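The seeding rework above leans on numpy's SeedSequence API; in isolation:

    import numpy as np

    child_seed = np.random.SeedSequence(1234)
    child_seeds = child_seed.spawn(10)           # mirrors the fixed num_seeds pool
    rng = np.random.default_rng(child_seeds[3])  # run index 3 gets its own stream

Note that new_seed must stay below the spawned pool size (10 here) for the indexing to be valid.
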
@@ -428,7 +461,9 @@ def single_run(
dag_metrics = DAGmetrics(truth=prod_obj.graph, est=dag)
all_shds.append(dag_metrics._shd())
- absolute_distance_to_mean = np.abs(all_shds - np.mean(all_shds))
+ absolute_distance_to_mean = np.abs(
+ np.array(all_shds) - np.mean(all_shds)
+ )
random_index_choice = np.random.choice(
np.flatnonzero(
absolute_distance_to_mean == np.min(absolute_distance_to_mean)
@@ -444,18 +479,17 @@ def single_run(
ground_truth, pd.DataFrame
):
raise AssertionError("something went wrong in the best DAG selection")
- else:
- if type(ground_truth) is not type(result):
- raise TypeError("ground truth and results need to have the same instance.")
+ elif type(ground_truth) is not type(result):
+ raise TypeError("ground truth and results need to have the same instance.")
get_metrics = DAGmetrics(truth=ground_truth, est=result)
- get_metrics.collect_metrics()
+ my_metrics = get_metrics.collect_metrics()
if harmonize_via == "best_dag_shd":
target_dag = prod_obj.graph
result_dag = nx.from_pandas_adjacency(df=result, create_using=nx.DiGraph)
- get_metrics.metrics["sid"] = int(SID(target=target_dag, pred=result_dag))
+ my_metrics["sid"] = int(SID(target=target_dag, pred=result_dag))
if harmonize_via == "cpdag_transform":
- get_metrics.metrics["shd"] = get_metrics._shd(count_anticausal_twice=False)
+ my_metrics["shd"] = get_metrics._shd(count_anticausal_twice=False)
if between_and_within_results and harmonize_via == "best_dag_shd":
if isinstance(prod_obj, ProcessCell):
@@ -473,24 +507,24 @@ def single_run(
get_within_metrics = DAGmetrics(
truth=within_metrics, est=result_plg.within_adjacency
)
- get_within_metrics.collect_metrics()
+ within_metrics = get_within_metrics.collect_metrics()
get_between_metrics = DAGmetrics(
truth=between_metrics, est=result_plg.between_adjacency
)
- get_between_metrics.collect_metrics()
+ between_metrics = get_between_metrics.collect_metrics()
# make dict
- w_b_dict = {"within": get_within_metrics, "between": get_between_metrics}
+ w_b_dict = {"within": within_metrics, "between": between_metrics}
for metric_name in ["precision", "recall"]:
for which_one, dct in w_b_dict.items():
within_between_metrics[alg_name][metric_name][which_one].append(
- dct.metrics[metric_name]
+ dct[metric_name]
)
- for metric_name, _ in get_metrics.metrics.items():
- metrics[alg_name][metric_name].append(get_metrics.metrics[metric_name])
+ for metric_name, _ in my_metrics.items():
+ metrics[alg_name][metric_name].append(my_metrics[metric_name])
if between_and_within_results:
return {
@@ -503,7 +537,7 @@ def single_run(
def causallearn2amat(causal_learn_graph: np.ndarray) -> np.ndarray:
- """Causallearn object helper function
+ """Causallearn object helper function.
Args:
causal_learn_graph (np.ndarray): causal lean output graph
@@ -522,8 +556,7 @@ def causallearn2amat(causal_learn_graph: np.ndarray) -> np.ndarray:
def to_dag_amat(cpdag_amat: pd.DataFrame) -> pd.DataFrame:
- """Turns PDAG into random member of the corresponding
- Markov equivalence class.
+ """Turns PDAG into random member of the corresponding Markov equivalence class.
Args:
cpdag_amat (pd.DataFrame): PDAG representing the MEC
@@ -531,7 +564,6 @@ def to_dag_amat(cpdag_amat: pd.DataFrame) -> pd.DataFrame:
Returns:
pd.DataFrame: DAG as member of MEC.
"""
-
pdag = PDAG.from_pandas_adjacency(cpdag_amat)
chosen_dag = pdag.to_dag()
if not nx.is_directed_acyclic_graph(chosen_dag):
diff --git a/causalAssembly/__init__.py b/causalAssembly/__init__.py
index 03bfac7..bf2a89f 100644
--- a/causalAssembly/__init__.py
+++ b/causalAssembly/__init__.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
diff --git a/causalAssembly/dag.py b/causalAssembly/dag.py
index 513ec79..6be1ca9 100644
--- a/causalAssembly/dag.py
+++ b/causalAssembly/dag.py
@@ -1,4 +1,19 @@
-"""DAG class"""
+"""DAG class.
+
+Copyright (c) 2023 Robert Bosch GmbH
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as published
+by the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+You should have received a copy of the GNU Affero General Public License
+along with this program. If not, see .
+"""
+
from __future__ import annotations
import logging
@@ -19,16 +34,22 @@
class DAG:
- """
- General class for dealing with directed acyclic graph i.e.
+ """General class for dealing with directed acyclic graph i.e.
+
graphs that are directed and must not contain any cycles.
"""
def __init__(
self,
- nodes: list | None = None,
- edges: list[tuple] | None = None,
+ nodes: list[str] | list[int] | set[int | str] | None = None,
+ edges: list[tuple[str | int, str | int]] | set[tuple[str | int, str | int]] | None = None,
):
+ """Initialize DAG.
+
+ Args:
+            nodes (list | None, optional): List of node names. Defaults to None.
+            edges (list[tuple] | None, optional): List of directed edges as
+                (tail, head) tuples. Defaults to None.
+ """
if nodes is None:
nodes = []
if edges is None:
@@ -50,7 +71,6 @@ def _add_edge(self, i, j):
self._edges.add((i, j))
# Check if graph is acyclic
- # TODO: Make check really after each edge is added?
if not self.is_acyclic():
raise ValueError(
"The edge set you provided \
@@ -63,6 +83,11 @@ def _add_edge(self, i, j):
@property
def random_state(self):
+ """Random state.
+
+ Returns:
+            np.random.Generator: The random state currently set.
+ """
return self._random_state
@random_state.setter
@@ -72,7 +97,7 @@ def random_state(self, r: np.random.Generator):
self._random_state = r
def add_edge(self, edge: tuple[str, str]):
- """Add edge to DAG
+ """Add edge to DAG.
Args:
edge (tuple[str, str]): Edge to add
@@ -80,7 +105,7 @@ def add_edge(self, edge: tuple[str, str]):
self._add_edge(*edge)
def add_edges_from(self, edges: list[tuple[str, str]]):
- """Add multiple edges to DAG
+ """Add multiple edges to DAG.
Args:
edges (list[tuple[str, str]]): Edges to add
@@ -129,8 +154,7 @@ def induced_subgraph(self, nodes: list[str]) -> DAG:
return DAG(nodes=nodes, edges=edges)
def is_adjacent(self, i: str, j: str) -> bool:
- """Return True if the graph contains an directed
- edge between i and j.
+ """Return True if the graph contains an directed edge between i and j.
Args:
i (str): node i.
@@ -142,9 +166,7 @@ def is_adjacent(self, i: str, j: str) -> bool:
return (j, i) in self.edges or (i, j) in self.edges
def is_clique(self, potential_clique: set) -> bool:
- """
- Check every pair of node X potential_clique is adjacent.
- """
+ """Check every pair of node X potential_clique is adjacent."""
return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2))
def is_acyclic(self) -> bool:
@@ -167,7 +189,7 @@ def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> DAG:
DAG
"""
assert pd_amat.shape[0] == pd_amat.shape[1]
- nodes = pd_amat.columns
+ nodes = list(pd_amat.columns)
all_connections = []
start, end = np.where(pd_amat != 0)
@@ -182,7 +204,7 @@ def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> DAG:
return DAG(nodes=nodes, edges=dir_edges)
def remove_edge(self, i: str, j: str):
- """Removes edge in question
+ """Removes edge in question.
Args:
i (str): tail
@@ -199,10 +221,10 @@ def remove_edge(self, i: str, j: str):
self._parents[j].discard(i)
def remove_node(self, node):
- """Remove a node from the graph"""
+ """Remove a node from the graph."""
self._nodes.remove(node)
- self._edges = {(i, j) for i, j in self._edges if i != node and j != node}
+ self._edges = {(i, j) for i, j in self._edges if node not in (i, j)}
for child in self._children[node]:
self._parents[child].remove(node)
@@ -215,8 +237,9 @@ def remove_node(self, node):
@property
def adjacency_matrix(self) -> pd.DataFrame:
- """Returns adjacency matrix where the i,jth
- entry being one indicates that there is an edge
+ """Returns adjacency matrix.
+
+ The i,jth entry being one indicates that there is an edge
from i to j. A zero indicates that there is no edge.
Returns:
@@ -232,7 +255,7 @@ def adjacency_matrix(self) -> pd.DataFrame:
return amat
def vstructs(self) -> set:
- """Retrieve v-structures
+ """Retrieve v-structures.
Returns:
set: set of all v-structures
@@ -246,7 +269,7 @@ def vstructs(self) -> set:
return vstructures
def copy(self):
- """Return a copy of the graph"""
+ """Return a copy of the graph."""
return DAG(nodes=self._nodes, edges=self._edges)
def show(self):
@@ -287,8 +310,7 @@ def num_nodes(self) -> int:
@property
def num_edges(self) -> int:
- """Number of directed edges
- in current DAG.
+ """Number of directed edges in current DAG.
Returns:
int: Number of directed edges
@@ -297,7 +319,7 @@ def num_edges(self) -> int:
@property
def sparsity(self) -> float:
- """Sparsity of the graph
+ """Sparsity of the graph.
Returns:
float: in [0,1]
@@ -307,8 +329,7 @@ def sparsity(self) -> float:
@property
def edges(self) -> list[tuple]:
- """Gives all directed edges in
- current DAG.
+ """Gives all directed edges in current DAG.
Returns:
list[tuple]: List of directed edges.
@@ -318,6 +339,7 @@ def edges(self) -> list[tuple]:
@property
def causal_order(self) -> list[str]:
"""Returns the causal order of the current graph.
+
Note that this order is in general not unique.
Returns:
@@ -360,7 +382,7 @@ def from_nx(cls, nx_dag: nx.DiGraph) -> DAG:
raise TypeError("DAG must be of type nx.DiGraph")
return DAG(nodes=list(nx_dag.nodes), edges=list(nx_dag.edges))
- def save_drf(self, filename: str, location: str = None):
+ def save_drf(self, filename: str, location: str | Path | None = None):
"""Writes a drf dict to file. Please provide the .pkl suffix!
Args:
@@ -368,8 +390,7 @@ def save_drf(self, filename: str, location: str = None):
location (str, optional): path to file in case it's not located in
the current working directory. Defaults to None.
"""
-
- if not location:
+ if location is None:
location = Path().resolve()
location_path = Path(location, filename)
@@ -391,10 +412,15 @@ def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
return _sample_from_drf(graph=self, size=size, smoothed=smoothed)
def to_cpdag(self) -> PDAG:
+ """Conversion to CPDAG.
+
+ Returns:
+            PDAG: CPDAG representing the Markov equivalence class of this DAG.
+ """
return dag2cpdag(dag=self.to_networkx())
@classmethod
- def load_drf(cls, filename: str, location: str = None) -> dict:
+ def load_drf(cls, filename: str, location: str | Path | None = None) -> dict:
"""Loads a drf dict from a .pkl file into the workspace.
Args:
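Taken together, the dag.py changes tighten the DAG class's typing without altering behavior. A quick tour of the touched API:

    dag = DAG(nodes=["a", "b", "c"], edges=[("a", "b"), ("b", "c")])
    dag.add_edge(("a", "c"))
    print(dag.causal_order)      # e.g. ['a', 'b', 'c']
    print(dag.adjacency_matrix)  # 0/1 pandas adjacency; rows are tails
    cpdag = dag.to_cpdag()       # CPDAG of the Markov equivalence class
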
diff --git a/causalAssembly/dag_utils.py b/causalAssembly/dag_utils.py
index 144ba9c..29c0954 100644
--- a/causalAssembly/dag_utils.py
+++ b/causalAssembly/dag_utils.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -12,6 +13,7 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see .
"""
+
import itertools
import logging
@@ -43,7 +45,6 @@ def merge_dags(
Returns:
nx.DiGraph: merged DAG
"""
-
for old_node_name, new_node_name in mapping.items():
if new_node_name not in target_dag.nodes():
raise ValueError(f"{new_node_name} does not exist in target_dag")
@@ -74,7 +75,7 @@ def merge_dags(
def merge_dags_via_edges(
left_dag: nx.DiGraph,
right_dag: nx.DiGraph,
- edges: list[tuple] = None,
+ edges: list[tuple] | None = None,
isolate_target_nodes: bool = False,
):
"""Merges two dags via a list of edges.
@@ -119,7 +120,6 @@ def merge_dags_via_edges(
merged_dag = nx.compose(left_dag, right_dag)
- # TODO experimental
merged_dag.add_edges_from(edges, **{"connector": True})
return merged_dag
@@ -143,7 +143,6 @@ def tuples_from_cartesian_product(l1: list, l2: list) -> list[tuple]:
[(0,'a'), (0,'b'), (0,'c'), (1,'a'), (1,'b'), (1,'c'), (2,'a'), (2,'b'), (2,'c')]
"""
- # TODO: This could take a long time for large graphs...
return [
(tail, head)
for tail, head in itertools.product(
@@ -153,11 +152,15 @@ def tuples_from_cartesian_product(l1: list, l2: list) -> list[tuple]:
]
-def _bootstrap_sample(rng: np.random.Generator, data: np.array, size: int = None) -> np.array:
+def _bootstrap_sample(
+ rng: np.random.Generator, data: np.ndarray, size: int | None = None
+) -> np.ndarray:
"""Generate bootstrap sample, i.e.
+
random sample with replacement of length `size` from 1-d array.
Args:
+ rng (np.random.Generator): Random number generator
data (np.array): 1-d array
size (int, optional): Size of bootstrap sample. If set to `None`,
we set size to length of input array
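At its core, the retyped _bootstrap_sample is numpy sampling with replacement; a plausible equivalent under the new signature:

    import numpy as np

    def bootstrap_sample(rng: np.random.Generator, data: np.ndarray,
                         size: int | None = None) -> np.ndarray:
        # sketch: size defaults to the input length, as the docstring states
        size = data.shape[0] if size is None else size
        return rng.choice(data, size=size, replace=True)
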
diff --git a/causalAssembly/drf_fitting.py b/causalAssembly/drf_fitting.py
index f22acc8..c57e2e7 100644
--- a/causalAssembly/drf_fitting.py
+++ b/causalAssembly/drf_fitting.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -12,6 +13,7 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see .
"""
+
from __future__ import annotations
import numpy as np
@@ -32,16 +34,18 @@
class DRF:
- """Wrapper around the corresponding R package:
+ """Wrapper around the corresponding R package.
+
Distributional Random Forests (Cevid et al., 2020).
- Closely adopted from their python wrapper."""
+ Closely adopted from their python wrapper.
+ """
def __init__(self, **fit_params):
+ """Initialize DRF object."""
self.fit_params = fit_params
- def fit(self, X: pd.DataFrame, Y: pd.DataFrame):
- """Fit DRF in order to estimate conditional
- distribution P(Y|X=x).
+ def fit(self, X: pd.DataFrame, Y: pd.DataFrame | pd.Series):
+ """Fit DRF in order to estimate conditional distribution P(Y|X=x).
Args:
X (pd.DataFrame): Conditioning set.
@@ -88,9 +92,10 @@ def produce_sample(
def fit_drf(graph: ProductionLineGraph | ProcessCell | DAG, data: pd.DataFrame):
- """Fit distributional random forests to the
- factorization implied by the current graph
+ """Fit distributional random forests to the factorization implied by the current graph.
+
Args:
+ graph (ProductionLineGraph | ProcessCell | DAG): Graph to fit the DRF to.
data (pd.DataFrame): Columns of dataframe need to match name and order of the graph
Raises:
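A hedged sketch of the per-node fit that fit_drf performs for each factor P(node | parents) of the factorization; the dataframe and column names are illustrative:

    drf = DRF()  # fit_params are forwarded to the underlying R call
    drf.fit(X=data[["parent_a", "parent_b"]], Y=data["child"])
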
diff --git a/causalAssembly/metrics.py b/causalAssembly/metrics.py
index 0aeabb8..940d93f 100644
--- a/causalAssembly/metrics.py
+++ b/causalAssembly/metrics.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -12,6 +13,7 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see .
"""
+
import copy
import networkx as nx
@@ -21,6 +23,7 @@
class DAGmetrics:
"""Class to calculate performance metrics for DAGs.
+
Make sure that the ground truth and the estimated DAG have the same order of
rows/columns. If these objects are nx.DiGraphs, make sure that graph.nodes()
have the same oder or pass a new nodelist to the class when initiating. The
@@ -33,8 +36,19 @@ def __init__(
self,
truth: nx.DiGraph | pd.DataFrame | np.ndarray,
est: nx.DiGraph | pd.DataFrame | np.ndarray,
- nodelist: list = None,
+ nodelist: list[str] | None = None,
):
+ """Inits the DAGmetrics class.
+
+ Args:
+            truth (nx.DiGraph | pd.DataFrame | np.ndarray): Ground-truth graph.
+            est (nx.DiGraph | pd.DataFrame | np.ndarray): Estimated graph.
+            nodelist (list, optional): Node order used for conversion. Defaults to None.
+
+        Raises:
+            TypeError: If the ground truth is not one of the permitted classes.
+            TypeError: If the estimated graph is not one of the permitted classes.
+ """
if not isinstance(truth, nx.DiGraph | pd.DataFrame | np.ndarray):
raise TypeError("Ground truth graph has to be one of the permitted classes.")
@@ -47,7 +61,7 @@ def __init__(
self.metrics = None
def _calculate_scores(self):
- """Calculate Precision, Recall and F1 and g score
+ """Calculate Precision, Recall and F1 and g score.
Return:
precision: float
@@ -59,8 +73,9 @@ def _calculate_scores(self):
gscore: float
max(0, (TP-FP))/(TP+FN)
"""
+ TWO = 2
assert self.est.shape == self.truth.shape and self.est.shape[0] == self.est.shape[1]
- TP = np.where((self.est + self.truth) == 2, 1, 0).sum(axis=1).sum()
+ TP = np.where((self.est + self.truth) == TWO, 1, 0).sum(axis=1).sum()
TP_FP = self.est.sum(axis=1).sum()
FP = TP_FP - TP
TP_FN = self.truth.sum(axis=1).sum()
@@ -96,12 +111,13 @@ def collect_metrics(self) -> dict[str, float | int]:
metrics = self._calculate_scores()
metrics["shd"] = self._shd()
self.metrics = metrics
+ return metrics
@classmethod
def _convert_to_numpy(
cls,
graph: nx.DiGraph | pd.DataFrame | np.ndarray,
- nodelist: list = None,
+ nodelist: list[str] | None = None,
):
if isinstance(graph, np.ndarray):
return copy.deepcopy(graph)
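Since collect_metrics() now returns the metrics dict in addition to storing it on the instance, callers like single_run can drop the attribute access. A minimal round trip, with hypothetical adjacency frames truth_amat and est_amat:

    dm = DAGmetrics(truth=truth_amat, est=est_amat)
    m = dm.collect_metrics()
    print(m["precision"], m["recall"], m["shd"])
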
diff --git a/causalAssembly/models_dag.py b/causalAssembly/models_dag.py
index 57bf614..d4d2c3d 100644
--- a/causalAssembly/models_dag.py
+++ b/causalAssembly/models_dag.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -12,6 +13,7 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see .
"""
+
from __future__ import annotations
import itertools
@@ -45,6 +47,8 @@
@dataclass
class NodeAttributes:
+ """Node Attributes."""
+
ALLOW_IN_EDGES = "allow_in_edges"
HIDDEN = "is_hidden"
@@ -64,8 +68,12 @@ def _sample_from_drf(
size=size, seed=prod_object.random_state
)[0]
else:
+ if prod_object.random_state is not None:
+ rng = prod_object.random_state
+ else:
+ rng = np.random.default_rng()
sample_dict[node] = _bootstrap_sample(
- rng=prod_object.random_state,
+ rng=rng,
data=prod_object.drf[node].dataset[0],
size=size,
)
@@ -154,8 +162,8 @@ def _interventional_sample_from_drf(
class ProcessCell:
- """
- Representation of a single Production Line Cell
+ """Representation of a single Production Line Cell.
+
(to model a station / a process in a production line
environment).
@@ -168,6 +176,11 @@ class ProcessCell:
"""
def __init__(self, name: str):
+ """Inits Process cell class.
+
+ Args:
+ name (str): _description_
+ """
self.name = name
self.graph: nx.DiGraph = nx.DiGraph()
@@ -201,7 +214,7 @@ def edges(self) -> list[tuple]:
@property
def num_nodes(self) -> int:
- """Number of nodes in the graph
+ """Number of nodes in the graph.
Returns:
int
@@ -210,7 +223,7 @@ def num_nodes(self) -> int:
@property
def num_edges(self) -> int:
- """Number of edges in the graph
+ """Number of edges in the graph.
Returns:
int
@@ -219,7 +232,7 @@ def num_edges(self) -> int:
@property
def sparsity(self) -> float:
- """Sparsity of the graph
+ """Sparsity of the graph.
Returns:
float: in [0,1]
@@ -229,8 +242,7 @@ def sparsity(self) -> float:
@property
def ground_truth(self) -> pd.DataFrame:
- """Returns the current ground truth as
- pandas adjacency.
+ """Returns the current ground truth as pandas adjacency.
Returns:
pd.DataFrame: Adjacenccy matrix.
@@ -240,6 +252,7 @@ def ground_truth(self) -> pd.DataFrame:
@property
def causal_order(self) -> list[str]:
"""Returns the causal order of the current graph.
+
Note that this order is in general not unique.
Returns:
@@ -258,7 +271,7 @@ def parents(self, of_node: str) -> list[str]:
"""
return list(self.graph.predecessors(of_node))
- def save_drf(self, filename: str, location: str = None):
+ def save_drf(self, filename: str, location: str | Path | None = None):
"""Writes a drf dict to file. Please provide the .pkl suffix!
Args:
@@ -266,7 +279,6 @@ def save_drf(self, filename: str, location: str = None):
location (str, optional): path to file in case it's not located in
the current working directory. Defaults to None.
"""
-
if not location:
location = Path().resolve()
@@ -281,7 +293,7 @@ def add_module(
allow_in_edges: bool = True,
mark_hidden: bool | list = False,
) -> str:
- """Adds module to cell graph. Module has to be as nx.DiGraph object
+ """Adds module to cell graph. Module has to be as nx.DiGraph object.
Args:
graph (nx.DiGraph): Graph to add to cell.
@@ -296,15 +308,13 @@ def add_module(
Returns:
str: prefix of Module created
"""
-
next_module_prefix = self.next_module_prefix()
node_renaming_dict = {
old_node_name: f"{self.name}_{next_module_prefix}_{old_node_name}"
for old_node_name in graph.nodes()
}
-
- self.modules[self.next_module_prefix()] = graph.copy()
+ self.modules[self.next_module_prefix()] = graph.copy() # type: ignore
graph = nx.relabel_nodes(graph, node_renaming_dict)
if allow_in_edges: # for later: mark nodes to not have incoming edges
@@ -323,14 +333,14 @@ def add_module(
nx.set_node_attributes(
graph, values=overwrite_dict
) # only overwrite the ones specified
- # TODO relabel attributes, i.e. name of the parents has changed now?
- # .update_attributes or so or keep and remove prefixes in bayesian network creation?
self.graph = nx.compose(self.graph, graph)
return next_module_prefix
def input_cellgraph_directly(self, graph: nx.DiGraph, allow_in_edges: bool = False):
- """Allow to input graphs on a cell-level. This should only be done if the graph
+ """Allow to input graphs on a cell-level.
+
+ This should only be done if the graph
is already available for the entire cell, otherwise `add_module()` is preferred.
Args:
@@ -362,23 +372,10 @@ def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
"""
return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed)
- def interventional_sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
- """Draw from the trained DRF.
-
- Args:
- size (int, optional): Number of samples to be drawn. Defaults to 10.
- smoothed (bool, optional): If set to true, marginal distributions will
- be sampled from smoothed bootstraps. Defaults to True.
-
- Returns:
- pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
- """
- return _interventional_sample_from_drf(prod_object=self, size=size, smoothed=smoothed)
-
def _generate_random_dag(self, n_nodes: int = 5, p: float = 0.1) -> nx.DiGraph:
- """
- Creates a random DAG by
- taking an arbitrary ordering of the specified number of nodes,
+ """Creates a random DAG.
+
+        It takes an arbitrary ordering of the specified number of nodes,
and then considers edges from node i to j only if i < j.
That constraint leads to DAGness by construction.
@@ -389,26 +386,37 @@ def _generate_random_dag(self, n_nodes: int = 5, p: float = 0.1) -> nx.DiGraph:
Returns:
nx.DiGraph:
"""
+ rng = self.random_state
+ if rng is None:
+ rng = np.random.default_rng()
dag = nx.DiGraph()
dag.add_nodes_from(range(0, n_nodes))
causal_order = list(dag.nodes)
- self.random_state.shuffle(causal_order)
+ rng.shuffle(causal_order)
all_forward_edges = itertools.combinations(causal_order, 2)
edges = np.array(list(all_forward_edges))
- random_choice = self.random_state.choice([False, True], p=[1 - p, p], size=edges.shape[0])
+ random_choice = rng.choice([False, True], p=[1 - p, p], size=edges.shape[0])
dag.add_edges_from(edges[random_choice])
return dag
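The construction above is the classic ordering trick: a graph whose edges only run forward in a fixed node ordering is acyclic by construction. Stripped to its core:

    import itertools

    import numpy as np

    rng = np.random.default_rng(0)
    order = list(range(5))
    rng.shuffle(order)
    forward = list(itertools.combinations(order, 2))  # i precedes j in `order`
    keep = rng.choice([False, True], p=[0.9, 0.1], size=len(forward))
    edges = [e for e, k in zip(forward, keep) if k]   # DAG edges for p = 0.1
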
def add_random_module(self, n_nodes: int = 7, p: float = 0.10):
+ """Add random module to the cell.
+
+ Args:
+            n_nodes (int, optional): Number of nodes in the random module. Defaults to 7.
+            p (float, optional): Probability of including each forward edge.
+                Defaults to 0.10.
+ """
randomdag = self._generate_random_dag(n_nodes=n_nodes, p=p)
self.add_module(graph=randomdag, allow_in_edges=True, mark_hidden=False)
def connect_by_module(self, m1: str, m2: str, edges: list[tuple]):
- """Connect two modules (by name, e.g. M2, M4) of the cell by a list
+ """Connect two modules.
+
+        Modules are referenced by name (e.g. M2, M4) and connected by a list
of edges with the original node names.
Args:
@@ -431,8 +439,9 @@ def connect_by_module(self, m1: str, m2: str, edges: list[tuple]):
self.graph.add_edges_from(new_edges)
def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph:
- """
- Add random edges to graph according to proportion,
+ """Add random edges to graph.
+
+        Edges are added according to the given proportion,
with restriction specified in node attributes.
Args:
@@ -445,7 +454,9 @@ def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph:
Returns:
nx.DiGraph: DAG with new edges added.
"""
-
+ rng = self.random_state
+ if rng is None:
+ rng = np.random.default_rng()
arrow_head_candidates = get_arrow_head_candidates_from_graph(
graph=self.graph, node_attributes_to_filter=NodeAttributes.ALLOW_IN_EDGES
)
@@ -460,9 +471,7 @@ def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph:
### choose edges uniformly according to sparsity parameter
chosen_edges = [
potential_edges[i]
- for i in self.random_state.choice(
- a=len(potential_edges), size=num_choices, replace=False
- )
+ for i in rng.choice(a=len(potential_edges), size=num_choices, replace=False)
]
self.graph.add_edges_from(chosen_edges)
@@ -474,24 +483,33 @@ def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph:
return self.graph
def __repr__(self):
+ """Repr method.
+
+ Returns:
+            str: Short representation of the cell.
+ """
return f"ProcessCell(name={self.name})"
def __str__(self):
+ """Str method.
+
+ Returns:
+            str: Summary of name, description, modules and nodes.
+ """
cell_description = {
"Cell Name: ": self.name,
"Description:": self.description if self.description else "n.a.",
"Modules:": self.no_of_modules,
"Nodes: ": self.num_nodes,
}
- s = str()
+ s = ""
for info, info_text in cell_description.items():
s += f"{info:<14}{info_text:>5}\n"
return s
def __verify_edges_are_allowed(self, m1: str, m2: str, edges: list[tuple]):
- """Check whether all starting point nodes
- (first value in edge tuple) are allowed.
+ """Check whether all starting point nodes (first value in edge tuple) are allowed.
Args:
m1 (str): Module1
@@ -504,8 +522,8 @@ def __verify_edges_are_allowed(self, m1: str, m2: str, edges: list[tuple]):
"""
source_nodes = set([e[0] for e in edges])
target_nodes = set([e[1] for e in edges])
- m1_nodes = set(self.modules.get(m1).nodes())
- m2_nodes = set(self.modules.get(m2).nodes())
+ m1_nodes = set(self.modules.get(m1).nodes()) # type: ignore
+ m2_nodes = set(self.modules.get(m2).nodes()) # type: ignore
if not source_nodes.issubset(m1_nodes):
raise ValueError(f"source nodes: {source_nodes} not include in {m1}s nodes: {m1_nodes}")
@@ -514,6 +532,7 @@ def __verify_edges_are_allowed(self, m1: str, m2: str, edges: list[tuple]):
def next_module_prefix(self) -> str:
"""Return the next module prefix, e.g.
+
if there are already 3 modules connected to the cell,
will return module_prefix4
@@ -524,6 +543,11 @@ def next_module_prefix(self) -> str:
@property
def module_prefix(self) -> str:
+ """Module prefix.
+
+ Returns:
+            str: The module prefix.
+ """
return self.__module_prefix
@module_prefix.setter
@@ -535,12 +559,19 @@ def module_prefix(self, module_prefix: str):
@property
def no_of_modules(self) -> int:
- return len(self.modules)
+ """Number of modules in the cell.
- def get_nodes_by_attribute(self, attr_name: str, submodule: str = None) -> list:
- pass
+ Returns:
+            int: Number of modules in the cell.
+ """
+ return len(self.modules)
def get_available_attributes(self):
+ """Get available attributes of the nodes in the graph.
+
+ Returns:
+            list: Names of all node attributes present in the graph.
+ """
available_attributes = set()
for node_tuple in self.graph.nodes(data=True):
for attribute_name in node_tuple[1].keys():
@@ -549,13 +580,20 @@ def get_available_attributes(self):
return list(available_attributes)
def to_cpdag(self) -> PDAG:
+ """To CPDAG conversion.
+
+ Returns:
+            PDAG: CPDAG corresponding to the cell graph.
+ """
return dag2cpdag(dag=self.graph)
def show(
self,
meta_desc: str = "",
):
- """Plots the cell graph by giving extra weight to nodes
+ """Plots the cell graph.
+
+        Extra weight is given to nodes
with high in- and out-degree.
Args:
@@ -583,11 +621,12 @@ def show(
vmin=-0.2,
vmax=1,
node_color=[
- (d + 10) / (max_in_degree + 10) for _, d in self.graph.in_degree(self.nodes)
+ (d + 10) / (max_in_degree + 10)
+ for _, d in self.graph.in_degree(self.nodes) # type: ignore
],
node_size=[
500 * (d + 1) / (max_out_degree + 1) for _, d in self.graph.out_degree(self.nodes)
- ],
+ ], # type: ignore
)
nx.draw_networkx_edges(
@@ -620,15 +659,22 @@ def _plot_cellgraph(
with_box=True,
meta_desc="",
):
- """Plots the cell graph by giving extra weight to nodes
+ """Plots the cell graph.
+
+        Extra weight is given to nodes
with high in- and out-degree.
Args:
- with_edges (bool, optional): Defaults to True.
- with_box (bool, optional): Defaults to True.
- meta_desc (str, optional): Defaults to "".
- center (_type_, optional): Defaults to np.array([0, 0]).
- fig_size (tuple, optional): Defaults to (2, 8).
+            ax (matplotlib.axes.Axes): Axes to draw on.
+            node_color (list): Color values for the nodes.
+            node_size (list): Sizes for the nodes.
+            center (np.ndarray, optional): Center of the layout. Defaults to np.array([0, 0]).
+            with_edges (bool, optional): Whether to draw edges. Defaults to True.
+            with_box (bool, optional): Whether to draw a bounding box. Defaults to True.
+            meta_desc (str, optional): Additional cell description. Defaults to "".
"""
cmap = plt.get_cmap("cividis")
@@ -674,7 +720,7 @@ def _plot_cellgraph(
PatchCollection(
[
FancyBboxPatch(
- center - [2, 1],
+ center - [2, 1], # type: ignore
4,
2.6,
boxstyle=BoxStyle("Round", pad=0.02),
@@ -695,7 +741,8 @@ def choose_edges_from_cells_randomly(
probability: float,
rng: np.random.Generator,
) -> list[tuple[str, str]]:
- """
+ """Choose cells randomly.
+
From two given cells (graphs), we take the cartesian product (end up with
from_cell.number_of_nodes x to_cell.number_of_nodes possible edges (node tuples).
@@ -708,12 +755,13 @@ def choose_edges_from_cells_randomly(
from_cell: ProcessCell from where we want the edges
to_cell: ProcessCell to where we want the edges
probability: between 0 and 1
+ rng (np.random.Generator): Random number generator to use.
Returns:
list[tuple[str, str]]: Chosen edges.
"""
-
- assert 0 <= probability <= 1.0
+ ONE = 1.0
+ assert 0 <= probability <= ONE
arrow_tail_candidates = list(from_cell.graph.nodes)
arrow_head_candidates = get_arrow_head_candidates_from_graph(graph=to_cell.graph)
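One plausible reading of the selection this function performs, in isolation (names are illustrative, not the actual implementation):

    import itertools

    import numpy as np

    rng = np.random.default_rng(0)
    tails, heads = ["a", "b"], ["x", "y", "z"]
    candidates = list(itertools.product(tails, heads))
    keep = rng.random(len(candidates)) < 0.3  # Bernoulli(probability) per edge
    chosen = [edge for edge, k in zip(candidates, keep) if k]
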
@@ -736,6 +784,7 @@ def get_arrow_head_candidates_from_graph(
graph: nx.DiGraph, node_attributes_to_filter: str = NodeAttributes.ALLOW_IN_EDGES
) -> list[str]:
"""Returns all arrow head (nodes where an arrow points to) nodes as list of candidates.
+
To later build a list of tuples of potential edges.
Args:
@@ -791,6 +840,27 @@ class ProductionLineGraph:
"""
def __init__(self):
+ """Inits ProductionLineGraph.
+
+ Raises:
+ AssertionError: _description_
+ TypeError: _description_
+ AssertionError: _description_
+ AssertionError: _description_
+ AssertionError: _description_
+ ValueError: _description_
+ AssertionError: _description_
+ AssertionError: _description_
+ ValueError: _description_
+ ValueError: _description_
+ ValueError: _description_
+ TypeError: _description_
+ AssertionError: _description_
+ AttributeError: _description_
+
+ Returns:
+ _type_: _description_
+ """
self._random_state = np.random.default_rng(seed=2023)
self.cells: dict[str, ProcessCell] = dict()
self.cell_prefix = "C"
@@ -803,6 +873,11 @@ def __init__(self):
@property
def random_state(self):
+ """Random state.
+
+ Returns:
+            np.random.Generator: The random state currently set.
+ """
return self._random_state
@random_state.setter
@@ -816,8 +891,7 @@ def __init_mutilated_dag(self):
@property
def graph(self) -> nx.DiGraph:
- """
- Returns a nx.DiGraph object of the actual graph.
+ """Returns a nx.DiGraph object of the actual graph.
The graph is only built HERE, i.e. all ProcessCells exist standalone in self.cells,
with no connections between their nodes yet.
@@ -867,7 +941,7 @@ def edges(self) -> list[tuple]:
@property
def num_nodes(self) -> int:
- """Number of nodes in the graph
+ """Number of nodes in the graph.
Returns:
int
@@ -876,7 +950,7 @@ def num_nodes(self) -> int:
@property
def num_edges(self) -> int:
- """Number of edges in the graph
+ """Number of edges in the graph.
Returns:
int
@@ -885,7 +959,7 @@ def num_edges(self) -> int:
@property
def sparsity(self) -> float:
- """Sparsity of the graph
+ """Sparsity of the graph.
Returns:
float: in [0,1]
@@ -895,8 +969,7 @@ def sparsity(self) -> float:
@property
def ground_truth(self) -> pd.DataFrame:
- """Returns the current ground truth as
- pandas adjacency.
+ """Returns the current ground truth as pandas adjacency.
Returns:
pd.DataFrame: Adjacenccy matrix.
@@ -913,8 +986,7 @@ def _get_union_graph(self) -> nx.DiGraph:
@property
def within_adjacency(self) -> pd.DataFrame:
- """Returns adjacency matrix ignoring all
- between-cell edges.
+ """Returns adjacency matrix ignoring all between-cell edges.
Returns:
pd.DataFrame: adjacency matrix
@@ -924,8 +996,7 @@ def within_adjacency(self) -> pd.DataFrame:
@property
def between_adjacency(self) -> pd.DataFrame:
- """Returns adjacency matrix ignoring all
- within-cell edges.
+ """Returns adjacency matrix ignoring all within-cell edges.
Returns:
pd.DataFrame: adjacency matrix
@@ -936,6 +1007,7 @@ def between_adjacency(self) -> pd.DataFrame:
@property
def causal_order(self) -> list[str]:
"""Returns the causal order of the current graph.
+
Note that this order is in general not unique.
Returns:
@@ -955,6 +1027,11 @@ def parents(self, of_node: str) -> list[str]:
return list(self.graph.predecessors(of_node))
def to_cpdag(self) -> PDAG:
+ """Convert to CPDAG.
+
+ Returns:
+            PDAG: CPDAG of the full production line graph.
+ """
return dag2cpdag(dag=self.graph)
def get_nodes_of_station(self, station_name: str) -> list:
@@ -990,7 +1067,7 @@ def __add_cell(self, cell: ProcessCell) -> ProcessCell:
raise ValueError(f"A cell with name: {cell.name} is already in the Production Line.")
- def new_cell(self, name: str = None, is_eol: bool = False) -> ProcessCell:
+ def new_cell(self, name: str | None = None, is_eol: bool = False) -> ProcessCell:
"""Add a new cell to the production line.
If no name is given, cell name is given by counting available cells + 1
@@ -1009,7 +1086,7 @@ def new_cell(self, name: str = None, is_eol: bool = False) -> ProcessCell:
actual_no_of_cells = len(self.cells.values())
c = ProcessCell(name=f"{self.cell_prefix}{actual_no_of_cells}")
- c.random_state = self.random_state
+ c.random_state = self.random_state # type: ignore
c.is_eol = is_eol
self.__add_cell(cell=c)
@@ -1048,7 +1125,7 @@ def connect_cells(
rng=self.random_state,
)
- prob_it += 1 # FIXME: a bit ugly and hard to read
+ prob_it += 1
self.cell_connector_edges.extend(chosen_edges)
if eol_cell := self.eol_cell:
@@ -1066,8 +1143,7 @@ def connect_cells(
self.cell_connector_edges.extend(chosen_eol_edges)
def copy(self) -> ProductionLineGraph:
- """Makes a full copy of the current
- ProductionLineGraph object
+ """Makes a full copy of the current ProductionLineGraph object.
Returns:
ProductionLineGraph: copyied object.
@@ -1088,14 +1164,18 @@ def copy(self) -> ProductionLineGraph:
return copy_graph
def connect_across_cells_manually(self, edges: list[tuple]):
- """Add edges manually across cells. You need to give the full name
+ """Add edges manually across cells.
+
+        You need to give the full node names (including the cell prefix).
Args:
edges (list[tuple]): list of edges to add
"""
self.cell_connector_edges.extend(edges)
def intervene_on(self, nodes_values: dict[str, RandomSymbol | float]):
- """Specify hard or soft intervention. If you want to intervene
+ """Specify hard or soft intervention.
+
+ If you want to intervene
upon more than one node provide a list of nodes to intervene on
and a list of corresponding values to set these nodes to.
(see example). The mutilated dag will automatically be
@@ -1126,15 +1206,15 @@ def intervene_on(self, nodes_values: dict[str, RandomSymbol | float]):
mutilated_dag.remove_edges_from(edges_to_remove)
drf_replace[node] = value
- self.mutilated_dags[
- f"do({list(nodes_values.keys())})"
- ] = mutilated_dag # specifiying the same set twice will override
+ self.mutilated_dags[f"do({list(nodes_values.keys())})"] = (
+            mutilated_dag  # specifying the same set twice will override
+ )
self.interventional_drf[f"do({list(nodes_values.keys())})"] = drf_replace
@property
def interventions(self) -> list:
- """Returns all interventions performed on the original graph
+ """Returns all interventions performed on the original graph.
Returns:
list: list of intervened upon nodes in do(x) notation.
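Hedged usage of the intervention API reformatted above, given a ProductionLineGraph `line` with fitted DRFs (the node name is hypothetical):

    line.intervene_on({"C0_M1_node": 5.0})  # hard intervention on one node
    print(line.interventions)               # do-notation keys for each intervention
    samples = line.sample_from_interventional_drf(size=100)
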
@@ -1168,13 +1248,13 @@ def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame:
@classmethod
def get_ground_truth(cls) -> ProductionLineGraph:
- """Loads in the ground_truth as described in the paper:
+ """Loads in the ground_truth as described in the paper.
+
causalAssembly: Generating Realistic Production Data for
Benchmarking Causal Discovery
Returns:
ProductionLineGraph: ground_truth for cells and line.
"""
-
gt_response = requests.get(DATA_GROUND_TRUTH, timeout=5)
ground_truth = json.loads(gt_response.text)
@@ -1201,7 +1281,8 @@ def get_ground_truth(cls) -> ProductionLineGraph:
@classmethod
def get_data(cls) -> pd.DataFrame:
- """Load in semi-synthetic data as described in the paper:
+ """Load in semi-synthetic data as described in the paper.
+
causalAssembly: Generating Realistic Production Data for
Benchmarking Causal Discovery
Returns:
@@ -1211,7 +1292,9 @@ def get_data(cls) -> pd.DataFrame:
@classmethod
def from_nx(cls, g: nx.DiGraph, cell_mapper: dict[str, list]):
- """Convert nx.DiGraph to ProductionLineGraph. Requires a dict mapping
+ """Convert nx.DiGraph to ProductionLineGraph.
+
+ Requires a dict mapping
where keys are cell names and values correspond to nodes within these cells.
Args:
@@ -1240,7 +1323,7 @@ def from_nx(cls, g: nx.DiGraph, cell_mapper: dict[str, list]):
return pline
@classmethod
- def load_drf(cls, filename: str, location: str = None):
+ def load_drf(cls, filename: str, location: str | Path | None = None):
"""Loads a drf dict from a .pkl file into the workspace.
Args:
@@ -1262,7 +1345,19 @@ def load_drf(cls, filename: str, location: str = None):
return pickle_drf
@classmethod
- def load_pline_from_pickle(cls, filename: str, location: str = None):
+ def load_pline_from_pickle(cls, filename: str, location: str | Path | None = None):
+ """Load production line graph from a pickle file.
+
+ Args:
+            filename (str): Name of the pickle file.
+            location (str | Path | None, optional): Path to the file in case it's not
+                located in the current working directory. Defaults to None.
+
+        Raises:
+            TypeError: If the pickle file does not contain a ProductionLineGraph.
+
+        Returns:
+            ProductionLineGraph: The loaded production line.
+ """
if not location:
location = Path().resolve()
@@ -1276,7 +1371,7 @@ def load_pline_from_pickle(cls, filename: str, location: str = None):
return pickle_line
- def save_drf(self, filename: str, location: str = None):
+ def save_drf(self, filename: str, location: str | Path | None = None):
"""Writes a drf dict to file. Please provide the .pkl suffix!
Args:
@@ -1284,7 +1379,6 @@ def save_drf(self, filename: str, location: str = None):
location (str, optional): path to file in case it's not located in
the current working directory. Defaults to None.
"""
-
if not location:
location = Path().resolve()
@@ -1328,7 +1422,7 @@ def sample_from_interventional_drf(
)
def hidden_nodes(self) -> list:
- """Returns list of nodes marked as hidden
+ """Returns list of nodes marked as hidden.
Returns:
list: of hidden nodes
@@ -1340,16 +1434,19 @@ def hidden_nodes(self) -> list:
]
def visible_nodes(self):
+ """All visible nodes in the graph.
+
+ Returns:
+            list: All nodes not marked as hidden.
+ """
return [node for node in self.nodes if node not in self.hidden_nodes()]
@property
def eol_cell(self) -> ProcessCell | None:
- """
+ """Returns ProcessCell.
- Returns ProcessCell: the EOL cell
+ the EOL cell
(if any single cell has attr .is_eol = True), otherwise returns None
- -------
-
"""
for cell in self.cells.values():
if cell.is_eol:
@@ -1358,6 +1455,7 @@ def eol_cell(self) -> ProcessCell | None:
@property
def ground_truth_visible(self) -> pd.DataFrame:
"""Generates a ground truth graph in form of a pandas adjacency matrix.
+
Row and column names correspond to visible.
The following integers can occur:
@@ -1368,7 +1466,6 @@ def ground_truth_visible(self) -> pd.DataFrame:
Returns:
pd.DataFrame: amat with values in {0,1,2}.
"""
-
if len(self.hidden_nodes()) == 0:
return self.ground_truth
else:
@@ -1385,8 +1482,7 @@ def ground_truth_visible(self) -> pd.DataFrame:
# reverse = lambda tuples: tuples[::-1]
def reverse(tuples):
- """
- Simple function to reverse tuple order
+ """Simple function to reverse tuple order.
Args:
tuples (tuple): tuple to reverse order
@@ -1404,7 +1500,7 @@ def reverse(tuples):
return amat_visible
def show(self, meta_description: list | None = None, fig_size: tuple = (15, 8)):
- """Plot full assembly line
+ """Plot full assembly line.
Args:
meta_description (list | None, optional): Specify additional cell info.
@@ -1458,27 +1554,51 @@ def show(self, meta_description: list | None = None, fig_size: tuple = (15, 8)):
)
def __str__(self):
+ """String method for ProductionLineGraph.
+
+ Returns:
+            str: Listing of the production line's cells.
+ """
s = "ProductionLine\n\n"
for cell in self.cells:
s += f"{cell}\n"
return s
def __getattr__(self, attrname):
+ """Get a cell by its name.
+
+ Args:
+            attrname (str): Name of the cell.
+
+        Raises:
+            AttributeError: If no cell with this name exists.
+
+        Returns:
+            ProcessCell: The requested cell.
+ """
if attrname not in self.cells.keys():
raise AttributeError(f"{attrname} is not a valid attribute (cell name?)")
return self.cells[attrname]
- # https://docs.python.org/3/library/pickle.html#pickle-protocol
- # TODO why is .cells enough, are the other member vars directly pickable?
def __getstate__(self):
+ """Get current state of the ProductionLineGraph.
+
+ Returns:
+            tuple: The instance dict and the cells dict.
+ """
return (self.__dict__, self.cells)
def __setstate__(self, state):
+ """Set state of the ProductionLineGraph.
+
+ Args:
+            state (tuple): Instance dict and cells dict, as produced by __getstate__.
+ """
self.__dict__, self.cells = state
@classmethod
def via_cell_number(cls, n_cells: int, cell_prefix: str = "C"):
- """Inits a ProductionLineGraph with predefined number of cells, e.g. n_cells = 3
+ """Inits a ProductionLineGraph with predefined number of cells, e.g. n_cells = 3.
Will create empty C0, C1 and C2 as cells if no other cell_prefix is given.
@@ -1496,8 +1616,7 @@ def via_cell_number(cls, n_cells: int, cell_prefix: str = "C"):
return pl
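A minimal usage sketch of `via_cell_number` together with the attribute-style cell access provided by `__getattr__` above (`add_random_module` is exercised in the tests later in this diff):

```python
from causalAssembly.models_dag import ProductionLineGraph

pl = ProductionLineGraph.via_cell_number(n_cells=3)  # creates empty cells C0, C1, C2
pl.C0.add_random_module(n_nodes=5)  # cells are reachable as attributes by name
print(pl)  # __str__ lists all cells
```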
def _pairs_with_hidden_mediators(self):
- """
- Return pairs of nodes with hidden mediators present.
+ """Return pairs of nodes with hidden mediators present.
Args:
graph (nx.DiGraph): DAG
@@ -1506,6 +1625,7 @@ def _pairs_with_hidden_mediators(self):
Returns:
list: list of tuples with pairs of nodes with hidden mediator
"""
+ TWO = 2
any_paths = []
visible = self.visible_nodes()
hidden_all = self.hidden_nodes()
@@ -1517,14 +1637,15 @@ def _pairs_with_hidden_mediators(self):
any_paths.append(path)
pairs_with_hidden_mediators = [
- (ls[0], ls[-1]) for ls in any_paths if np.all(np.isin(ls[1:-1], hidden)) and len(ls) > 2
+ (ls[0], ls[-1])
+ for ls in any_paths
+ if np.all(np.isin(ls[1:-1], hidden)) and len(ls) > TWO
]
return pairs_with_hidden_mediators
def _pairs_with_hidden_confounders(self) -> dict:
- """
- Returns node-pairs that have a common hidden confounder
+ """Returns node-pairs that have a common hidden confounder.
Returns:
dict: Dictionary with keys equal to tuples of node-pairs
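A hedged sketch of consuming the {0, 1, 2} encoding of `ground_truth_visible`; `pline` is assumed to be a ProductionLineGraph whose hidden nodes are already marked (the marking step is not part of this diff):

```python
# Entries: 0 = no edge, 1 = directed edge, 2 = hidden confounding between the pair.
amat = pline.ground_truth_visible
directed = [(i, j) for i in amat.index for j in amat.columns if amat.loc[i, j] == 1]
confounded = {frozenset((i, j)) for i in amat.index for j in amat.columns if amat.loc[i, j] == 2}
```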
diff --git a/causalAssembly/models_fcm.py b/causalAssembly/models_fcm.py
index 2fc890f..00002fc 100644
--- a/causalAssembly/models_fcm.py
+++ b/causalAssembly/models_fcm.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -12,6 +13,7 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see .
"""
+
from __future__ import annotations
import logging
@@ -52,6 +54,12 @@ class FCM:
"""
def __init__(self, name: str | None = None, seed: int = 2023):
+ """Inits the FCM class.
+
+ Args:
+ name (str | None, optional): name of the model. Defaults to None.
+ seed (int, optional): seed for the random number generator. Defaults to 2023.
+ """
self.name = name
self._random_state = np.random.default_rng(seed=seed)
self.__init_dag()
@@ -76,6 +84,7 @@ def source_nodes(self) -> list:
@property
def causal_order(self) -> list[Symbol]:
"""Returns the causal order of the current graph.
+
Note that this order is in general not unique. To
ensure uniqueness, we additionally sort lexicographically.
@@ -104,7 +113,7 @@ def edges(self) -> list[tuple]:
@property
def num_nodes(self) -> int:
- """Number of nodes in the graph
+ """Number of nodes in the graph.
Returns:
int
@@ -113,7 +122,7 @@ def num_nodes(self) -> int:
@property
def num_edges(self) -> int:
- """Number of edges in the graph
+ """Number of edges in the graph.
Returns:
int
@@ -122,7 +131,7 @@ def num_edges(self) -> int:
@property
def sparsity(self) -> float:
- """Sparsity of the graph
+ """Sparsity of the graph.
Returns:
float: in [0,1]
@@ -132,8 +141,7 @@ def sparsity(self) -> float:
@property
def ground_truth(self) -> pd.DataFrame:
- """Returns the current ground truth as
- pandas adjacency.
+ """Returns the current ground truth as pandas adjacency.
Returns:
pd.DataFrame: Adjacency matrix.
@@ -141,8 +149,8 @@ def ground_truth(self) -> pd.DataFrame:
return nx.to_pandas_adjacency(self.graph, weight=None)
@property
- def interventions(self) -> list:
- """Returns all interventions performed on the original graph
+ def interventions(self) -> list[str]:
+ """Returns all interventions performed on the original graph.
Returns:
list: list of intervened upon nodes in do(x) notation.
@@ -199,6 +207,7 @@ def parents_of(self, node: Symbol, which_graph: nx.DiGraph) -> list[Symbol]:
def causal_order_of(self, which_graph: nx.DiGraph) -> list[Symbol]:
"""Returns the causal order of the chosen graph.
+
Note that this order is in general not unique. To
ensure uniqueness, we additionally sort lexicographically.
@@ -208,7 +217,9 @@ def causal_order_of(self, which_graph: nx.DiGraph) -> list[Symbol]:
return list(nx.lexicographical_topological_sort(which_graph, key=lambda x: str(x)))
def source_nodes_of(self, which_graph: nx.DiGraph) -> list:
- """Returns the source nodes of a chosen graph. This is mainly for
+ """Returns the source nodes of a chosen graph.
+
+ This is mainly for
choosing different mutilated DAGs.
Args:
@@ -224,8 +235,8 @@ def source_nodes_of(self, which_graph: nx.DiGraph) -> list:
]
def input_fcm(self, fcm: list[Eq]):
- """
- Automatically builds up DAG according to the FCM fed in.
+ """Automatically builds up DAG according to the FCM fed in.
+
Args:
fcm (list): list of sympy equations generated as:
```[python]
@@ -300,9 +311,12 @@ def sample(
snr: None | float = 1 / 2,
source_df: None | pd.DataFrame = None,
) -> pd.DataFrame:
- """Draw samples from the joint distribution that factorizes
- according to the DAG implied by the FCM fed in. To avoid
- unexpected/unintended behavior, avoid defining fully
+ r"""Sample from joint.
+
+ Draw samples from the joint distribution that factorizes
+ according to the DAG implied by the FCM fed in.
+
+ To avoid unexpected/unintended behavior, avoid defining fully
deterministic equation systems.
If parameters in noise terms are additive and left unevaluated,
they're set according to a chosen Signal-To-Noise (SNR) ratio.
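A self-contained sketch of this workflow, mirroring the fixtures in tests/test_models_fcm.py further down; the parameter name `size` is taken from the `_sample` docstring below, and the explicit noise scale in `eq_y` means the SNR mechanism is not triggered here:

```python
from sympy import Eq, symbols
from sympy.stats import Normal, Uniform

from causalAssembly.models_fcm import FCM

x, y = symbols("x,y")
eq_x = Eq(x, Uniform("error", left=-1, right=1))
eq_y = Eq(y, 2 * x + Normal("error", 0, 1))

fcm = FCM(name="toy", seed=2023)
fcm.input_fcm([eq_x, eq_y])
df = fcm.sample(size=100)  # pandas DataFrame with one column per node
```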
@@ -349,8 +363,9 @@ def interventional_sample(
snr: None | float = 1 / 2,
source_df: None | pd.DataFrame = None,
) -> pd.DataFrame:
- """Draw samples from the interventional distribution that factorizes
- according to the mutilated DAG after performing one or multiple
+ r"""Draw samples from the interventional distribution.
+
+ It factorizes according to the mutilated DAG after performing one or multiple
interventions. Otherwise the method behaves similar to sampling from the
non-interventional joint distribution. By default samples are drawn from the
first intervention you performed. If you intervened upon more than one node,
@@ -409,7 +424,9 @@ def _sample(
snr: None | float = 1 / 2,
source_df: None | pd.DataFrame = None,
) -> pd.DataFrame:
- """Draw samples from the joint distribution that factorizes
+ r"""Draw samples from the joint distribution.
+
+ The joint distribution factorizes
according to the DAG implied by the FCM fed in. To avoid
unexpected/unintended behavior, avoid defining fully
deterministic equation systems.
@@ -422,6 +439,7 @@ def _sample(
Args:
size (int): Number of samples to draw.
+ which_graph (nx.DiGraph): Which graph to sample from.
additive_gaussian_noise (bool, optional): add Gaussian noise to otherwise deterministic equations. Defaults to False.
snr (None | float, optional): Signal-to-noise ratio
\\( SNR = \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\).
@@ -439,7 +457,6 @@ def _sample(
pd.DataFrame: Data frame with rows of length size and columns equal to the
number of nodes in the graph.
"""
-
if source_df is not None and not self.__source_df_condition(source_df):
raise AssertionError("Names in source_df don't match nodenames in graph.")
@@ -512,11 +529,11 @@ def _sample(
+ str(fcm_expr)
+ " according to the given SNR."
)
- noise_var = df[str(order)].var() / snr
+ noise_var = df[str(order)].var() / snr # type: ignore
df[str(order)] = df[str(order)] + sympy_sample(
fcm_expr.atoms(RandomSymbol)
.pop()
- .subs(self.__unfree_symbol(fcm_expr), np.sqrt(noise_var)),
+ .subs(self.__unfree_symbol(fcm_expr), np.sqrt(noise_var)), # type: ignore
size=size,
seed=self._random_state,
)
@@ -532,9 +549,11 @@ def _sample(
)
else:
noise = symbols("noise")
- noise_var = df[str(order)].var() / snr
+ noise_var = df[str(order)].var() / snr # type: ignore
df[str(noise)] = self._random_state.normal(
- loc=0, scale=np.sqrt(noise_var), size=size
+ loc=0,
+ scale=np.sqrt(noise_var), # type: ignore
+ size=size,
)
fcm_expr = which_graph.nodes[order]["term"] + noise
df[str(order)] = self.__eval_expression(df=df, fcm_expr=fcm_expr)
@@ -552,7 +571,7 @@ def __unfree_symbol(self, fcm_expr) -> set[Symbol]:
}.pop()
def __eval_expression(self, df: pd.DataFrame, fcm_expr: Expr) -> pd.DataFrame:
- """Eval given fcm_expression with the values in given dataframe
+ """Eval given fcm_expression with the values in given dataframe.
Args:
df (pd.DataFrame): Data frame.
@@ -561,7 +580,6 @@ def __eval_expression(self, df: pd.DataFrame, fcm_expr: Expr) -> pd.DataFrame:
Returns:
pd.DataFrame: Data frame after eval.
"""
-
correct_order = list(ordered(fcm_expr.free_symbols)) # self.__return_ordered_args(fcm_expr)
cols = [str(col) for col in correct_order]
evaluator = lambdify(correct_order, fcm_expr, "scipy")
@@ -569,11 +587,11 @@ def __eval_expression(self, df: pd.DataFrame, fcm_expr: Expr) -> pd.DataFrame:
return evaluator(*[df[col] for col in cols])
def __distribution_parameters_explicit(self, order: Symbol, which_graph: nx.DiGraph) -> bool:
- """Returns true if distribution parameters
- are given explicitly, not symbolically.
+ """Returns true if distribution parameters are given explicitly, not symbolically.
Args:
order (node): node in graph
+ which_graph (nx.DiGraph): which graph to choose.
Returns:
bool:
@@ -596,7 +614,9 @@ def __source_df_condition(self, source_df: pd.DataFrame) -> bool:
)
def intervene_on(self, nodes_values: dict[Symbol, RandomSymbol | float]):
- """Specify hard or soft intervention. If you want to intervene
+ """Specify hard or soft intervention.
+
+ If you want to intervene
upon more than one node provide a list of nodes to intervene on
and a list of corresponding values to set these nodes to.
(see example). The mutilated dag will automatically be
@@ -625,7 +645,6 @@ def intervene_on(self, nodes_values: dict[Symbol, RandomSymbol | float]):
```
"""
-
if not set(nodes_values.keys()).issubset(set(self.nodes)):
raise AssertionError(
"One or more nodes you want to intervene upon are not in the graph."
@@ -640,11 +659,11 @@ def intervene_on(self, nodes_values: dict[Symbol, RandomSymbol | float]):
mutilated_dag.remove_edges_from(edges_to_remove)
mutilated_dag.nodes[node]["term"] = intervention.rhs
- self.mutilated_dags[
- f"do({list(nodes_values.keys())})"
- ] = mutilated_dag # specifiying the same set twice will override
+ self.mutilated_dags[f"do({list(nodes_values.keys())})"] = (
+ mutilated_dag # specifying the same set twice will override
+ )
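A hedged sketch of the intervention workflow just described, continuing the toy FCM from the sampling example above:

```python
from sympy import symbols

x = symbols("x")
fcm.intervene_on({x: 1.0})  # hard intervention do(x = 1)
df_do = fcm.interventional_sample(size=100)  # defaults to the first intervention
print(fcm.interventions)  # recorded in do(x) notation
```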
- def show(self, header: str | None = None, with_nodenames: bool = True) -> plt:
+ def show(self, header: str | None = None, with_nodenames: bool = True):
"""Plots the current DAG.
Args:
@@ -659,10 +678,8 @@ def show(self, header: str | None = None, with_nodenames: bool = True) -> plt:
header = ""
return self._show(which_graph=self.graph, header=header, with_nodenames=with_nodenames)
- def show_mutilated_dag(
- self, which_intervention: str | int = 0, with_nodenames: bool = True
- ) -> plt:
- """Plot mutilated DAG
+ def show_mutilated_dag(self, which_intervention: str | int = 0, with_nodenames: bool = True):
+ """Plot mutilated DAG.
Args:
which_intervention (str | int, optional): Which interventional distribution
@@ -683,8 +700,12 @@ def show_mutilated_dag(
)
def _show(self, which_graph: nx.DiGraph, header: str, with_nodenames: bool):
- """Plots the graph by giving extra weight to nodes
- with high in- and out-degree.
+ """Plots the graph by giving extra weight to nodes with high in- and out-degree.
+
+ Args:
+ which_graph (nx.DiGraph): graph to plot.
+ header (str): plot title.
+ with_nodenames (bool): whether to draw node names.
"""
cmap = plt.get_cmap("Blues")
fig, ax = plt.subplots()
@@ -712,10 +733,10 @@ def _show(self, which_graph: nx.DiGraph, header: str, with_nodenames: bool):
vmax=1,
node_color=[
(d + 10) / (max_in_degree + 10) for _, d in which_graph.in_degree(self.nodes)
- ],
+ ], # type: ignore
node_size=[
500 * (d + 1) / (max_out_degree + 1) for _, d in which_graph.out_degree(self.nodes)
- ],
+ ], # type: ignore
)
if with_nodenames:
diff --git a/causalAssembly/pdag.py b/causalAssembly/pdag.py
index b709b11..c7964e7 100644
--- a/causalAssembly/pdag.py
+++ b/causalAssembly/pdag.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -27,17 +28,32 @@
class PDAG:
- """
- Class for dealing with partially directed graph i.e.
- graphs that contain both directed and undirected edges.
+ """Class for dealing with partially directed graphs.
+
+ That is, graphs that contain both directed and undirected edges.
"""
def __init__(
self,
- nodes: list | None = None,
- dir_edges: list[tuple] | None = None,
- undir_edges: list[tuple] | None = None,
+ nodes: list[str] | list[int] | set[str] | set[int] | None = None,
+ dir_edges: list[tuple[str, str]]
+ | list[tuple[int, int]]
+ | set[tuple[str, str]]
+ | set[tuple[int, int]]
+ | None = None,
+ undir_edges: list[tuple[str, str]]
+ | list[tuple[int, int]]
+ | set[tuple[str, str]]
+ | set[tuple[int, int]]
+ | None = None,
):
+ """Inits the PDAG class.
+
+ Args:
+ nodes (list | None, optional): nodes of the PDAG. Defaults to None.
+ dir_edges (list[tuple] | None, optional): directed edges as (tail, head) pairs. Defaults to None.
+ undir_edges (list[tuple] | None, optional): undirected edges as node pairs. Defaults to None.
+ """
if nodes is None:
nodes = []
if dir_edges is None:
@@ -80,7 +96,7 @@ def _add_undir_edge(self, i, j):
self._undirected_neighbors[i].add(j)
self._undirected_neighbors[j].add(i)
- def children(self, node: str) -> set:
+ def children(self, node: str | int) -> set:
"""Gives all children of node `node`.
Args:
@@ -94,7 +110,7 @@ def children(self, node: str) -> set:
else:
return set()
- def parents(self, node: str) -> set:
+ def parents(self, node: str | int) -> set:
"""Gives all parents of node `node`.
Args:
@@ -108,7 +124,7 @@ def parents(self, node: str) -> set:
else:
return set()
- def neighbors(self, node: str) -> set:
+ def neighbors(self, node: str | int) -> set:
"""Gives all neighbors of node `node`.
Args:
@@ -122,9 +138,8 @@ def neighbors(self, node: str) -> set:
else:
return set()
- def undir_neighbors(self, node: str) -> set:
- """Gives all undirected neighbors
- of node `node`.
+ def undir_neighbors(self, node: str | int) -> set:
+ """Gives all undirected neighbors of node `node`.
Args:
node (str): node in current PDAG.
@@ -138,8 +153,7 @@ def undir_neighbors(self, node: str) -> set:
return set()
def is_adjacent(self, i: str, j: str) -> bool:
- """Return True if the graph contains an directed
- or undirected edge between i and j.
+ """Return True if the graph contains an directed or undirected edge between i and j.
Args:
i (str): node i.
@@ -156,9 +170,7 @@ def is_adjacent(self, i: str, j: str) -> bool:
)
def is_clique(self, potential_clique: set) -> bool:
- """
- Check every pair of node X potential_clique is adjacent.
- """
+ """Check every pair of node X potential_clique is adjacent."""
return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2))
@classmethod
@@ -172,7 +184,7 @@ def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> PDAG:
PDAG
"""
assert pd_amat.shape[0] == pd_amat.shape[1]
- nodes = pd_amat.columns
+ nodes = list(pd_amat.columns)
all_connections = []
start, end = np.where(pd_amat != 0)
@@ -188,7 +200,7 @@ def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> PDAG:
return PDAG(nodes=nodes, dir_edges=dir_edges, undir_edges=undir_edges)
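A minimal sketch of the constructor just shown; a symmetric pair of nonzero entries is read as an undirected edge, a single nonzero entry as a directed one:

```python
import pandas as pd

from causalAssembly.pdag import PDAG

nodes = ["A", "B", "C"]
amat = pd.DataFrame(
    [
        [0, 1, 1],  # A -> B is directed; A - C is undirected (mirrored below)
        [0, 0, 0],
        [1, 0, 0],
    ],
    columns=nodes,
    index=nodes,
)
pdag = PDAG.from_pandas_adjacency(amat)
assert set(pdag.dir_edges) == {("A", "B")}
assert pdag.num_undir_edges == 1
```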
def remove_edge(self, i: str, j: str):
- """Removes edge in question
+ """Removes edge in question.
Args:
i (str): tail
@@ -211,6 +223,7 @@ def remove_edge(self, i: str, j: str):
def undir_to_dir_edge(self, tail: str, head: str):
"""Takes a undirected edge and turns it into a directed one.
+
tail indicates the starting node of the edge and head the end node, i.e.
tail -> head.
@@ -236,12 +249,12 @@ def undir_to_dir_edge(self, tail: str, head: str):
self._add_dir_edge(i=tail, j=head)
def remove_node(self, node):
- """Remove a node from the graph"""
+ """Remove a node from the graph."""
self._nodes.remove(node)
- self._dir_edges = {(i, j) for i, j in self._dir_edges if i != node and j != node}
+ self._dir_edges = {(i, j) for i, j in self._dir_edges if node not in (i, j)}
- self._undir_edges = {(i, j) for i, j in self._undir_edges if i != node and j != node}
+ self._undir_edges = {(i, j) for i, j in self._undir_edges if node not in (i, j)}
for child in self._children[node]:
self._parents[child].remove(node)
@@ -261,8 +274,7 @@ def remove_node(self, node):
self._undirected_neighbors.pop(node, "I was never here")
def to_dag(self) -> nx.DiGraph:
- """
- Algorithm as described in Chickering (2002):
+ r"""Algorithm as described in Chickering (2002).
1. From PDAG P create DAG G containing all directed edges from P
2. Repeat the following: Select node v in P s.t.
@@ -278,7 +290,6 @@ def to_dag(self) -> nx.DiGraph:
Returns:
nx.DiGraph: DAG that belongs to the MEC implied by the PDAG
"""
-
pdag = self.copy()
dag = nx.DiGraph()
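A brief usage sketch of the extension step described in the docstring above, reusing `pdag` from the previous example:

```python
import networkx as nx

dag = pdag.to_dag()  # some consistent DAG extension of the PDAG
assert isinstance(dag, nx.DiGraph)
assert nx.is_directed_acyclic_graph(dag)
assert ("A", "B") in dag.edges  # directed edges of the PDAG are preserved
```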
@@ -308,7 +319,7 @@ def to_dag(self) -> nx.DiGraph:
for edge in pdag.undir_edges:
if node in edge:
incident_node = set(edge) - {node}
- dag.add_edge(*incident_node, node)
+ dag.add_edge(*incident_node, node) # type: ignore
pdag.remove_node(node)
break
@@ -324,7 +335,9 @@ def to_dag(self) -> nx.DiGraph:
@property
def adjacency_matrix(self) -> pd.DataFrame:
- """Returns adjacency matrix where the i,jth
+ """Returns adjacency matrix.
+
+ The i,jth
entry being one indicates that there is an edge
from i to j. A zero indicates that there is no edge.
@@ -343,7 +356,9 @@ def adjacency_matrix(self) -> pd.DataFrame:
return amat
def _amat_to_dag(self) -> pd.DataFrame:
- """Transform the adjacency matrix of an PDAG to the adjacency
+ """Adjacency matrix to random DAG.
+
+ Transform the adjacency matrix of a PDAG to the adjacency
matrix of SOME DAG in the Markov equivalence class.
Returns:
@@ -375,7 +390,7 @@ def _amat_to_dag(self) -> pd.DataFrame:
)
def vstructs(self) -> set:
- """Retrieve v-structures
+ """Retrieve v-structures.
Returns:
set: set of all v-structures
@@ -389,8 +404,8 @@ def vstructs(self) -> set:
return vstructures
def copy(self):
- """Return a copy of the graph"""
- return PDAG(nodes=self._nodes, dir_edges=self._dir_edges, undir_edges=self._undir_edges)
+ """Return a copy of the graph."""
+ return PDAG(nodes=self._nodes, dir_edges=self._dir_edges, undir_edges=self._undir_edges) # type: ignore
def show(self):
"""Plot PDAG."""
@@ -414,8 +429,8 @@ def to_networkx(self) -> nx.MultiDiGraph:
return nx_pdag
def _meek_mec_enumeration(self, pdag: PDAG, dag_list: list):
- """Recursion algorithm which recursively applies the
- following steps:
+ """Recursion algorithm which recursively applies the following steps.
+
1. Orient the first undirected edge found.
2. Apply Meek rules.
3. Recurse with each direction of the oriented edge.
@@ -455,8 +470,8 @@ def _meek_mec_enumeration(self, pdag: PDAG, dag_list: list):
self._meek_mec_enumeration(pdag=g_copy, dag_list=dag_list)
def to_allDAGs(self) -> list[nx.DiGraph]:
- """Recursion algorithm which recursively applies the
- following steps:
+ """Recursion algorithm which recursively applies the following steps.
+
1. Orient the first undirected edge found.
2. Apply Meek rules.
3. Recurse with each direction of the oriented edge.
@@ -473,8 +488,7 @@ def to_allDAGs(self) -> list[nx.DiGraph]:
# use Meek's cpdag2alldag
def _apply_meek_rules(self, G: PDAG) -> PDAG:
- """Apply all four Meek rules to a
- PDAG turning it into a CPDAG.
+ """Apply all four Meek rules to a PDAG turning it into a CPDAG.
Args:
G (PDAG): PDAG to complete
@@ -511,13 +525,13 @@ def to_random_dag(self) -> nx.DiGraph:
return nx.from_pandas_adjacency(to_dag_candidate.adjacency_matrix, create_using=nx.DiGraph)
@property
- def nodes(self) -> list:
+ def nodes(self) -> list[str] | list[int]:
"""Get all nods in current PDAG.
Returns:
list: list of nodes.
"""
- return sorted(list(self._nodes))
+ return sorted(list(self._nodes)) # type: ignore
@property
def nnodes(self) -> int:
@@ -530,8 +544,7 @@ def nnodes(self) -> int:
@property
def num_undir_edges(self) -> int:
- """Number of undirected edges
- in current PDAG.
+ """Number of undirected edges in current PDAG.
Returns:
int: Number of undirected edges
@@ -540,8 +553,7 @@ def num_undir_edges(self) -> int:
@property
def num_dir_edges(self) -> int:
- """Number of directed edges
- in current PDAG.
+ """Number of directed edges in current PDAG.
Returns:
int: Number of directed edges
@@ -550,8 +562,7 @@ def num_dir_edges(self) -> int:
@property
def num_adjacencies(self) -> int:
- """Number of adjacent nodes
- in current PDAG.
+ """Number of adjacent nodes in current PDAG.
Returns:
int: Number of adjacent nodes
@@ -560,8 +571,7 @@ def num_adjacencies(self) -> int:
@property
def undir_edges(self) -> list[tuple]:
- """Gives all undirected edges in
- current PDAG.
+ """Gives all undirected edges in current PDAG.
Returns:
list[tuple]: List of undirected edges.
@@ -570,8 +580,7 @@ def undir_edges(self) -> list[tuple]:
@property
def dir_edges(self) -> list[tuple]:
- """Gives all directed edges in
- current PDAG.
+ """Gives all directed edges in current PDAG.
Returns:
list[tuple]: List of directed edges.
@@ -598,7 +607,9 @@ def vstructs(dag: nx.DiGraph) -> set:
def rule_1(pdag: PDAG) -> PDAG:
- """Given the following pattern X -> Y - Z. Orient Y - Z to Y -> Z
+ """Meeks first rule.
+
+ Given the following pattern X -> Y - Z. Orient Y - Z to Y -> Z
if X and Z are non-adjacent (otherwise a new v-structure arises).
Args:
@@ -625,7 +636,9 @@ def rule_1(pdag: PDAG) -> PDAG:
def rule_2(pdag: PDAG) -> PDAG:
- """Given the following directed triple
+ """Meeks 2nd rule.
+
+ Given the following directed triple
X -> Y -> Z where X - Z are indeed adjacent.
Orient X - Z to X -> Z otherwise a cycle arises.
@@ -653,7 +666,9 @@ def rule_2(pdag: PDAG) -> PDAG:
def rule_3(pdag: PDAG) -> PDAG:
- """Orient X - Z to X -> Z, whenever there are two triples
+ """Meeks third rule.
+
+ Orient X - Z to X -> Z, whenever there are two triples
X - Y1 -> Z and X - Y2 -> Z such that Y1 and Y2 are non-adjacent.
Args:
@@ -662,6 +677,7 @@ def rule_3(pdag: PDAG) -> PDAG:
Returns:
PDAG: PDAG after application of rule.
"""
+ TWO = 2
copy_pdag = pdag.copy()
for edge in copy_pdag.undir_edges:
reverse_edge = edge[::-1]
@@ -670,7 +686,7 @@ def rule_3(pdag: PDAG) -> PDAG:
# if true that tail - node1 -> head and tail - node2 -> head
# while {node1 U node2} = 0 then orient tail -> head
orient = False
- if len(copy_pdag.undir_neighbors(tail)) >= 2:
+ if len(copy_pdag.undir_neighbors(tail)) >= TWO:
undir_n = copy_pdag.undir_neighbors(tail)
selection = [
(node1, node2)
@@ -688,7 +704,9 @@ def rule_3(pdag: PDAG) -> PDAG:
def rule_4(pdag: PDAG) -> PDAG:
- """Orient X - Y1 to X -> Y1, whenever there are
+ """Meeks 4th rule.
+
+ Orient X - Y1 to X -> Y1, whenever there are
two triples with X - Z and X - Y1 <- Z and X - Y2 -> Z
such that Y1 and Y2 are non-adjacent.
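A hedged sketch of the module-level rule functions on a toy PDAG; rule 1 is shown on the pattern X -> Y - Z with X and Z non-adjacent:

```python
from causalAssembly.pdag import PDAG, rule_1

toy = PDAG(nodes=["X", "Y", "Z"], dir_edges=[("X", "Y")], undir_edges=[("Y", "Z")])
oriented = rule_1(toy)
assert ("Y", "Z") in oriented.dir_edges  # Y - Z has been oriented to Y -> Z
```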
@@ -720,7 +738,7 @@ def rule_4(pdag: PDAG) -> PDAG:
def dag2cpdag(dag: nx.DiGraph) -> PDAG:
- """Convertes a DAG into its unique CPDAG
+ """Convertes a DAG into its unique CPDAG.
Args:
dag (nx.DiGraph): DAG the CPDAG corresponds to.
@@ -728,7 +746,7 @@ def dag2cpdag(dag: nx.DiGraph) -> PDAG:
Returns:
PDAG: unique CPDAG
"""
- copy_dag = dag.copy()
+ copy_dag: nx.DiGraph = dag.copy() # type: ignore
# Skeleton
skeleton = nx.to_pandas_adjacency(copy_dag.to_undirected())
# v-Structures
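Closing the loop, a small sketch combining `dag2cpdag` with the MEC enumeration above (behavior as the docstrings state):

```python
import networkx as nx

from causalAssembly.pdag import dag2cpdag

collider = nx.DiGraph([("A", "B"), ("C", "B")])  # v-structure A -> B <- C
cpdag = dag2cpdag(collider)
assert set(cpdag.dir_edges) == {("A", "B"), ("C", "B")}  # v-structures stay oriented
assert len(cpdag.to_allDAGs()) == 1  # this MEC contains a single DAG
```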
diff --git a/licenses.txt b/licenses.txt
new file mode 100644
index 0000000..94610af
--- /dev/null
+++ b/licenses.txt
@@ -0,0 +1,76 @@
+ Name Version License
+ ghp-import 2.1.0 Apache Software License
+ importlib_metadata 8.7.0 Apache Software License
+ importlib_resources 6.5.2 Apache Software License
+ requests 2.32.4 Apache Software License
+ tzdata 2025.2 Apache Software License
+ watchdog 6.0.0 Apache Software License
+ packaging 25.0 Apache Software License; BSD License
+ python-dateutil 2.9.0.post0 Apache Software License; BSD License
+ verspec 0.1.0 Apache Software License; BSD License
+ uv 0.8.4 Apache Software License; MIT License
+ coverage 7.10.2 Apache-2.0
+ Jinja2 3.1.6 BSD License
+ MarkupSafe 3.0.2 BSD License
+ Pygments 2.19.2 BSD License
+ babel 2.17.0 BSD License
+ colorama 0.4.6 BSD License
+ contourpy 1.3.2 BSD License
+ cycler 0.12.1 BSD License
+ idna 3.10 BSD License
+ kiwisolver 1.4.8 BSD License
+ mike 2.1.3 BSD License
+ mkdocs 1.6.1 BSD License
+ mpmath 1.3.0 BSD License
+ networkx 3.4.2 BSD License
+ nodeenv 1.9.1 BSD License
+ numpy 2.2.6 BSD License
+ pandas 2.3.1 BSD License
+ pip-tools 7.5.0 BSD License
+ pycparser 2.22 BSD License
+ scipy 1.15.3 BSD License
+ sympy 1.14.0 BSD License
+ rpy2-robjects 3.6.1 GNU General Public License v2 or later (GPLv2+)
+ rpy2-rinterface 3.6.2 GPL-2.0-or-later
+ fonttools 4.59.0 MIT
+ identify 2.6.12 MIT
+ pytest-cov 6.2.1 MIT
+ PyYAML 6.0.2 MIT License
+ backrefs 5.9 MIT License
+ cffi 1.17.1 MIT License
+ cfgv 3.4.0 MIT License
+ charset-normalizer 3.4.2 MIT License
+ exceptiongroup 1.3.0 MIT License
+ iniconfig 2.1.0 MIT License
+ mergedeep 1.3.4 MIT License
+ mkdocs-get-deps 0.2.0 MIT License
+ mkdocs-material 9.6.16 MIT License
+ mkdocs-material-extensions 1.3.1 MIT License
+ paginate 0.5.7 MIT License
+ platformdirs 4.3.8 MIT License
+ pluggy 1.6.0 MIT License
+ pre_commit 4.2.0 MIT License
+ pymdown-extensions 10.16.1 MIT License
+ pyparsing 3.2.3 MIT License
+ pyproject_hooks 1.2.0 MIT License
+ pytest 8.4.1 MIT License
+ pytz 2025.2 MIT License
+ ruff 0.12.7 MIT License
+ six 1.17.0 MIT License
+ tzlocal 5.3.1 MIT License
+ virtualenv 20.33.0 MIT License
+ distlib 0.4.0 Python Software Foundation License
+ matplotlib 3.10.5 Python Software Foundation License
+ filelock 3.18.0 The Unlicense (Unlicense)
+ Markdown 3.8.2 UNKNOWN
+ build 1.3.0 UNKNOWN
+ click 8.2.1 UNKNOWN
+ griffe 1.9.0 UNKNOWN
+ mkdocs-autorefs 1.4.2 UNKNOWN
+ mkdocstrings 0.30.0 UNKNOWN
+ mkdocstrings-python 1.16.12 UNKNOWN
+ pillow 11.3.0 UNKNOWN
+ pyyaml_env_tag 1.1 UNKNOWN
+ typing_extensions 4.14.1 UNKNOWN
+ urllib3 2.5.0 UNKNOWN
+ zipp 3.23.0 UNKNOWN
diff --git a/pyproject.toml b/pyproject.toml
index 98be8c5..8e7c1e6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
"requests",
]
readme = "README.md"
-license = {file = "LICENSE"}
+license = { file = "LICENSE" }
[project.urls]
Homepage = "https://github.com/boschresearch/causalAssembly"
@@ -38,10 +38,13 @@ dev = [
"mkdocs",
"mkdocs-material",
"mkdocstrings[python]",
- "ruff",
+ "pip-licenses",
"pip-tools",
"pre-commit",
"pytest",
+ "pytest-cov",
+ "ruff",
+ "uv",
]
[tool.setuptools.packages.find]
@@ -54,35 +57,62 @@ version = { file = "VERSION" }
[tool.ruff]
-select = ["A", "E", "F", "I"]
-ignore = []
-
-fixable = ["A", "B", "C", "D", "E", "F", "I"]
-unfixable = []
-
-line-length = 100
-
-target-version = "py310"
-
exclude = [
- ".bzr",
- ".direnv",
- ".eggs",
- ".git",
- ".hg",
+ ".github",
".mypy_cache",
- ".nox",
- ".pants.d",
+ ".pytest_cache",
".ruff_cache",
- ".svn",
- ".tox",
".venv",
- "__pypackages__",
- "_build",
- "buck-out",
- "build",
- "dist",
- "node_modules",
"venv",
+ #"notebooks/",
+]
+
+extend-include = ["*.ipynb"]
+
+line-length = 100
+
+[tool.ruff.lint]
+select = [
+ "E", # pycodestyle
+ "F", # pyflakes
+ "UP", # pyupgrade
+ "D", # pydocstyle
+ "PL", # pylint
+ "TD", # flake8-todos
+ "C90", # McCabe
]
-per-file-ignores = {}
+
+ignore = []
+
+# Allow fix for all enabled rules (when `--fix` is provided).
+fixable = ["ALL"]
+unfixable = []
+
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
+[tool.ruff.lint.mccabe]
+max-complexity = 27
+
+[tool.ruff.lint.pylint]
+max-args = 15
+max-branches = 30
+max-statements = 100
+
+[tool.pytest.ini_options]
+addopts = "--cov=gresit --cov-fail-under=60"
+testpaths = ["tests"]
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a788f3f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,33 @@
+# This file was autogenerated by uv via the following command:
+# uv pip compile pyproject.toml --output-file=requirements.txt --annotation-style=line
+certifi==2025.8.3 # via requests
+cffi==1.17.1 # via rpy2-rinterface
+charset-normalizer==3.4.2 # via requests
+contourpy==1.3.2 # via matplotlib
+cycler==0.12.1 # via matplotlib
+fonttools==4.59.0 # via matplotlib
+idna==3.10 # via requests
+jinja2==3.1.6 # via rpy2-robjects
+kiwisolver==1.4.8 # via matplotlib
+markupsafe==3.0.2 # via jinja2
+matplotlib==3.10.5 # via causalassembly (pyproject.toml)
+mpmath==1.3.0 # via sympy
+networkx==3.4.2 # via causalassembly (pyproject.toml)
+numpy==2.2.6 # via contourpy, matplotlib, pandas, scipy, causalassembly (pyproject.toml)
+packaging==25.0 # via matplotlib
+pandas==2.3.1 # via causalassembly (pyproject.toml)
+pillow==11.3.0 # via matplotlib
+pycparser==2.22 # via cffi
+pyparsing==3.2.3 # via matplotlib
+python-dateutil==2.9.0.post0 # via matplotlib, pandas
+pytz==2025.2 # via pandas
+requests==2.32.4 # via causalassembly (pyproject.toml)
+rpy2==3.6.2 # via causalassembly (pyproject.toml)
+rpy2-rinterface==3.6.2 # via rpy2, rpy2-robjects
+rpy2-robjects==3.6.1 # via rpy2
+scipy==1.15.3 # via causalassembly (pyproject.toml)
+six==1.17.0 # via python-dateutil
+sympy==1.14.0 # via causalassembly (pyproject.toml)
+tzdata==2025.2 # via pandas
+tzlocal==5.3.1 # via rpy2-robjects
+urllib3==2.5.0 # via requests
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 3666155..3a26ad6 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -4,80 +4,90 @@
#
# pip-compile --allow-unsafe --annotation-style=line --extra=dev --no-emit-index-url --no-emit-trusted-host --output-file=requirements_dev.txt pyproject.toml
#
-babel==2.14.0 # via mkdocs-material
-build==1.0.3 # via pip-tools
-certifi==2024.2.2 # via requests
-cffi==1.16.0 # via rpy2
+babel==2.17.0 # via mkdocs-material
+backrefs==5.9 # via mkdocs-material
+build==1.3.0 # via pip-tools
+certifi==2025.8.3 # via requests
+cffi==1.17.1 # via rpy2-rinterface
cfgv==3.4.0 # via pre-commit
-charset-normalizer==3.3.2 # via requests
-click==8.1.7 # via mkdocs, mkdocstrings, pip-tools
+charset-normalizer==3.4.2 # via requests
+click==8.2.1 # via mkdocs, pip-tools
colorama==0.4.6 # via griffe, mkdocs-material
-contourpy==1.2.0 # via matplotlib
+contourpy==1.3.2 # via matplotlib
+coverage[toml]==7.10.2 # via coverage, pytest-cov
cycler==0.12.1 # via matplotlib
-distlib==0.3.8 # via virtualenv
-exceptiongroup==1.2.0 # via pytest
-filelock==3.13.1 # via virtualenv
-fonttools==4.48.1 # via matplotlib
+distlib==0.4.0 # via virtualenv
+exceptiongroup==1.3.0 # via pytest
+filelock==3.18.0 # via virtualenv
+fonttools==4.59.0 # via matplotlib
ghp-import==2.1.0 # via mkdocs
-griffe==0.40.1 # via mkdocstrings-python
-identify==2.5.34 # via pre-commit
-idna==3.6 # via requests
-importlib-metadata==7.0.1 # via mike
-importlib-resources==6.1.1 # via mike
-iniconfig==2.0.0 # via pytest
-jinja2==3.1.3 # via mike, mkdocs, mkdocs-material, mkdocstrings, rpy2
-kiwisolver==1.4.5 # via matplotlib
-markdown==3.5.2 # via mkdocs, mkdocs-autorefs, mkdocs-material, mkdocstrings, pymdown-extensions
-markupsafe==2.1.5 # via jinja2, mkdocs, mkdocstrings
-matplotlib==3.8.2 # via causalAssembly (pyproject.toml)
-mergedeep==1.3.4 # via mkdocs
-mike==2.0.0 # via causalAssembly (pyproject.toml)
-mkdocs==1.5.3 # via causalAssembly (pyproject.toml), mike, mkdocs-autorefs, mkdocs-material, mkdocstrings
-mkdocs-autorefs==0.5.0 # via mkdocstrings
-mkdocs-material==9.5.9 # via causalAssembly (pyproject.toml)
+griffe==1.9.0 # via mkdocstrings-python
+identify==2.6.12 # via pre-commit
+idna==3.10 # via requests
+importlib-metadata==8.7.0 # via mike
+importlib-resources==6.5.2 # via mike
+iniconfig==2.1.0 # via pytest
+jinja2==3.1.6 # via mike, mkdocs, mkdocs-material, mkdocstrings, rpy2-robjects
+kiwisolver==1.4.8 # via matplotlib
+markdown==3.8.2 # via mkdocs, mkdocs-autorefs, mkdocs-material, mkdocstrings, pymdown-extensions
+markupsafe==3.0.2 # via jinja2, mkdocs, mkdocs-autorefs, mkdocstrings
+matplotlib==3.10.5 # via causalAssembly (pyproject.toml)
+mergedeep==1.3.4 # via mkdocs, mkdocs-get-deps
+mike==2.1.3 # via causalAssembly (pyproject.toml)
+mkdocs==1.6.1 # via causalAssembly (pyproject.toml), mike, mkdocs-autorefs, mkdocs-material, mkdocstrings
+mkdocs-autorefs==1.4.2 # via mkdocstrings, mkdocstrings-python
+mkdocs-get-deps==0.2.0 # via mkdocs
+mkdocs-material==9.6.16 # via causalAssembly (pyproject.toml)
mkdocs-material-extensions==1.3.1 # via mkdocs-material
-mkdocstrings[python]==0.24.0 # via causalAssembly (pyproject.toml), mkdocstrings-python
-mkdocstrings-python==1.8.0 # via mkdocstrings
+mkdocstrings[python]==0.30.0 # via causalAssembly (pyproject.toml), mkdocstrings-python
+mkdocstrings-python==1.16.12 # via mkdocstrings
mpmath==1.3.0 # via sympy
-networkx==3.2.1 # via causalAssembly (pyproject.toml)
-nodeenv==1.8.0 # via pre-commit
-numpy==1.26.4 # via causalAssembly (pyproject.toml), contourpy, matplotlib, pandas, scipy
-packaging==23.2 # via build, matplotlib, mkdocs, pytest
-paginate==0.5.6 # via mkdocs-material
-pandas==2.2.0 # via causalAssembly (pyproject.toml)
+networkx==3.4.2 # via causalAssembly (pyproject.toml)
+nodeenv==1.9.1 # via pre-commit
+numpy==2.2.6 # via causalAssembly (pyproject.toml), contourpy, matplotlib, pandas, scipy
+packaging==25.0 # via build, matplotlib, mkdocs, pytest
+paginate==0.5.7 # via mkdocs-material
+pandas==2.3.1 # via causalAssembly (pyproject.toml)
pathspec==0.12.1 # via mkdocs
-pillow==10.2.0 # via matplotlib
-pip-tools==7.3.0 # via causalAssembly (pyproject.toml)
-platformdirs==4.2.0 # via mkdocs, mkdocstrings, virtualenv
-pluggy==1.4.0 # via pytest
-pre-commit==3.6.1 # via causalAssembly (pyproject.toml)
-pycparser==2.21 # via cffi
-pygments==2.17.2 # via mkdocs-material
-pymdown-extensions==10.7 # via mkdocs-material, mkdocstrings
-pyparsing==3.1.1 # via matplotlib, mike
-pyproject-hooks==1.0.0 # via build
-pytest==8.0.0 # via causalAssembly (pyproject.toml)
-python-dateutil==2.8.2 # via ghp-import, matplotlib, pandas
-pytz==2024.1 # via pandas
-pyyaml==6.0.1 # via mike, mkdocs, pre-commit, pymdown-extensions, pyyaml-env-tag
-pyyaml-env-tag==0.1 # via mkdocs
-regex==2023.12.25 # via mkdocs-material
-requests==2.31.0 # via causalAssembly (pyproject.toml), mkdocs-material
-rpy2==3.5.15 # via causalAssembly (pyproject.toml)
-ruff==0.2.1 # via causalAssembly (pyproject.toml)
-scipy==1.12.0 # via causalAssembly (pyproject.toml)
-six==1.16.0 # via python-dateutil
-sympy==1.12 # via causalAssembly (pyproject.toml)
-tomli==2.0.1 # via build, pip-tools, pyproject-hooks, pytest
-tzdata==2024.1 # via pandas
-tzlocal==5.2 # via rpy2
-urllib3==2.2.0 # via requests
+pillow==11.3.0 # via matplotlib
+pip-licenses==5.0.0 # via causalAssembly (pyproject.toml)
+pip-tools==7.5.0 # via causalAssembly (pyproject.toml)
+platformdirs==4.3.8 # via mkdocs-get-deps, virtualenv
+pluggy==1.6.0 # via pytest, pytest-cov
+pre-commit==4.2.0 # via causalAssembly (pyproject.toml)
+prettytable==3.16.0 # via pip-licenses
+pycparser==2.22 # via cffi
+pygments==2.19.2 # via mkdocs-material, pytest
+pymdown-extensions==10.16.1 # via mkdocs-material, mkdocstrings
+pyparsing==3.2.3 # via matplotlib, mike
+pyproject-hooks==1.2.0 # via build, pip-tools
+pytest==8.4.1 # via causalAssembly (pyproject.toml), pytest-cov
+pytest-cov==6.2.1 # via causalAssembly (pyproject.toml)
+python-dateutil==2.9.0.post0 # via ghp-import, matplotlib, pandas
+pytz==2025.2 # via pandas
+pyyaml==6.0.2 # via mike, mkdocs, mkdocs-get-deps, pre-commit, pymdown-extensions, pyyaml-env-tag
+pyyaml-env-tag==1.1 # via mike, mkdocs
+requests==2.32.4 # via causalAssembly (pyproject.toml), mkdocs-material
+rpy2==3.6.2 # via causalAssembly (pyproject.toml)
+rpy2-rinterface==3.6.2 # via rpy2, rpy2-robjects
+rpy2-robjects==3.6.1 # via rpy2
+ruff==0.12.7 # via causalAssembly (pyproject.toml)
+scipy==1.15.3 # via causalAssembly (pyproject.toml)
+six==1.17.0 # via python-dateutil
+sympy==1.14.0 # via causalAssembly (pyproject.toml)
+tomli==2.2.1 # via build, coverage, pip-licenses, pip-tools, pytest
+typing-extensions==4.14.1 # via exceptiongroup, mkdocstrings-python
+tzdata==2025.2 # via pandas
+tzlocal==5.3.1 # via rpy2-robjects
+urllib3==2.5.0 # via requests
+uv==0.8.4 # via causalAssembly (pyproject.toml)
verspec==0.1.0 # via mike
-virtualenv==20.25.0 # via pre-commit
-watchdog==4.0.0 # via mkdocs
-wheel==0.42.0 # via pip-tools
-zipp==3.17.0 # via importlib-metadata
+virtualenv==20.33.0 # via pre-commit
+watchdog==6.0.0 # via mkdocs
+wcwidth==0.2.13 # via prettytable
+wheel==0.45.1 # via pip-tools
+zipp==3.23.0 # via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
-pip==24.0 # via pip-tools
-setuptools==69.1.0 # via nodeenv, pip-tools
+pip==25.2 # via pip-tools
+setuptools==80.9.0 # via pip-tools
diff --git a/tests/test_dag.py b/tests/test_dag.py
index f5586de..c53f065 100644
--- a/tests/test_dag.py
+++ b/tests/test_dag.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -12,6 +13,7 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see .
"""
+
import networkx as nx
import numpy as np
import pandas as pd
@@ -22,29 +24,62 @@
class TestDAG:
+ """Test class for the DAG class.
+
+ Returns:
+ _type_: _description_
+ """
+
@pytest.fixture(scope="class")
def example_dag(self) -> DAG:
+ """Example dag.
+
+ Returns:
+ DAG: _description_
+ """
return DAG(nodes=["A", "B", "C"], edges=[("A", "B"), ("A", "C")])
def test_instance_is_created(self):
+ """Check whether an instance of DAG can be created with nodes only."""
dag = DAG(nodes=["A", "B", "C"])
assert isinstance(dag, DAG)
def test_edges(self, example_dag: DAG):
- assert example_dag.num_edges == 2
+ """Test edges of the DAG.
+
+ Args:
+ example_dag (DAG): example fixture.
+ """
+ TWO = 2
+ assert example_dag.num_edges == TWO
assert set(example_dag.edges) == {("A", "B"), ("A", "C")}
def test_children(self, example_dag: DAG):
+ """Test children of the DAG.
+
+ Args:
+ example_dag (DAG): example fixture.
+ """
assert set(example_dag.children(of_node="A")) == {"B", "C"}
assert example_dag.children(of_node="B") == []
assert example_dag.children(of_node="C") == []
def test_parents(self, example_dag: DAG):
+ """TEst parents of the DAG.
+
+ Args:
+ example_dag (DAG): _description_
+ """
assert example_dag.parents(of_node="A") == []
assert set(example_dag.parents(of_node="B")) == {"A"}
assert set(example_dag.parents(of_node="C")) == {"A"}
def test_from_pandas_adjacency(self, example_dag: DAG):
+ """Test import from pandas adjacency matrix.
+
+ Args:
+ example_dag (DAG): example fixture.
+ """
amat = pd.DataFrame(
[[0, 1, 1], [0, 0, 0], [0, 0, 0]],
columns=["A", "B", "C"],
@@ -55,6 +90,11 @@ def test_from_pandas_adjacency(self, example_dag: DAG):
assert set(from_pandas_pdag.edges) == set(example_dag.edges)
def test_remove_edge(self, example_dag: DAG):
+ """Test removing edges.
+
+ Args:
+ example_dag (DAG): example fixture.
+ """
assert ("A", "C") in example_dag.edges
example_dag.remove_edge("A", "C")
assert ("A", "C") not in example_dag.edges
@@ -62,34 +102,49 @@ def test_remove_edge(self, example_dag: DAG):
example_dag.remove_edge("B", "A")
def test_remove_node(self, example_dag: DAG):
+ """Test removing nodes.
+
+ Args:
+ example_dag (DAG): example fixture.
+ """
assert "C" in example_dag.nodes
example_dag.remove_node("C")
assert "C" not in example_dag.nodes
def test_to_cpdag(self):
+ """Test conversion to CPDAG."""
+ TWO = 2
dag = DAG()
dag.add_edges_from([("A", "B"), ("A", "C")])
cpdag = dag.to_cpdag()
assert isinstance(cpdag, PDAG)
- assert cpdag.num_undir_edges == 2
+ assert cpdag.num_undir_edges == TWO
assert cpdag.num_dir_edges == 0
def test_adjacency_matrix(self, example_dag: DAG):
+ """Test return of adjacency matrix.
+
+ Args:
+ example_dag (DAG): example fixture.
+ """
amat = example_dag.adjacency_matrix
assert amat.shape[0] == amat.shape[1] == example_dag.num_nodes
assert amat.sum().sum() == example_dag.num_edges
def test_to_networkx(self, example_dag: DAG):
+ """Test conversion to NetworkX graph."""
nxg = example_dag.to_networkx()
assert isinstance(nxg, nx.DiGraph)
assert set(nxg.edges) == set(example_dag.edges)
def test_from_networkx(self, example_dag: DAG):
+ """Test conversion from NetworkX graph to DAG."""
nxg = example_dag.to_networkx()
from_networkx_dag = DAG.from_nx(nxg)
assert set(from_networkx_dag.edges) == set(example_dag.edges)
def test_error_when_cyclic(self):
+ """Test that an error is raised when trying to create a cyclic DAG."""
dag = DAG()
with pytest.raises(ValueError):
dag.add_edges_from([("A", "C"), ("C", "D"), ("D", "A")])
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 3533340..6be09e1 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -12,6 +13,7 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see .
"""
+
import random
import string
@@ -24,8 +26,19 @@
class TestDAGmetrics:
+ """Test metrics class.
+
+ Returns:
+ _type_: _description_
+ """
+
@pytest.fixture(scope="class")
def gt(self):
+ """Set up ground truth adjacency matrix.
+
+ Returns:
+ pd.DataFrame: 5 x 5 ground-truth adjacency matrix.
+ """
names = list(string.ascii_lowercase)[0:5]
temp_np = np.array(
[
@@ -40,6 +53,11 @@ def gt(self):
@pytest.fixture(scope="class")
def est(self):
+ """Set up estimated adjacency matrix.
+
+ Returns:
+ pd.DataFrame: 5 x 5 estimated adjacency matrix.
+ """
names = list(string.ascii_lowercase)[0:5]
temp_np = np.array(
[
@@ -53,11 +71,23 @@ def est(self):
return pd.DataFrame(temp_np, columns=names, index=names)
def test_pd_input_works(self, gt, est):
+ """Test that the metrics class can be initialized with pandas DataFrames.
+
+ Args:
+ gt (pd.DataFrame): ground-truth fixture.
+ est (pd.DataFrame): estimated-graph fixture.
+ """
met = DAGmetrics(truth=gt, est=est)
assert np.array_equal(met.truth, gt.to_numpy())
assert np.array_equal(met.est, est.to_numpy())
def test_nx_input_works(self, gt, est):
+ """Test that the metrics class can be initialized with networkx graphs.
+
+ Args:
+ gt (pd.DataFrame): ground-truth fixture.
+ est (pd.DataFrame): estimated-graph fixture.
+ """
gt_nx = nx.from_pandas_adjacency(gt, create_using=nx.DiGraph)
est_nx = nx.from_pandas_adjacency(est, create_using=nx.DiGraph)
met = DAGmetrics(truth=gt_nx, est=est_nx)
@@ -65,6 +95,12 @@ def test_nx_input_works(self, gt, est):
assert np.array_equal(met.est, est.to_numpy())
def test_pd_change_order(self, gt, est):
+ """Test that the metrics class can handle different node orders in input DataFrames.
+
+ Args:
+ gt (pd.DataFrame): ground-truth fixture.
+ est (pd.DataFrame): estimated-graph fixture.
+ """
nodelist = list(string.ascii_lowercase)[0:5]
random.shuffle(nodelist)
met = DAGmetrics(truth=gt, est=est, nodelist=nodelist)
@@ -72,6 +108,12 @@ def test_pd_change_order(self, gt, est):
assert np.array_equal(met.est, est.reindex(nodelist)[nodelist].to_numpy())
def test_nx_change_order(self, gt, est):
+ """Test that the metrics class can handle different node orders in networkx graphs.
+
+ Args:
+ gt (pd.DataFrame): ground-truth fixture.
+ est (pd.DataFrame): estimated-graph fixture.
+ """
nodelist = list(string.ascii_lowercase)[0:5]
random.shuffle(nodelist)
@@ -83,11 +125,18 @@ def test_nx_change_order(self, gt, est):
assert np.array_equal(met.est, est.reindex(nodelist)[nodelist].to_numpy())
def test_metrics_values(self, gt, est):
+ """Test the metrics values.
+
+ Args:
+ gt (pd.DataFrame): ground-truth fixture.
+ est (pd.DataFrame): estimated-graph fixture.
+ """
+ THREE = 3
met = DAGmetrics(truth=gt, est=est)
- met.collect_metrics()
+ metrics = met.collect_metrics()
- assert met.metrics["shd"] == 3
- assert met.metrics["gscore"] >= 0 and met.metrics["gscore"] <= 1
- assert met.metrics["f1"] >= 0 and met.metrics["f1"] <= 1
- assert met.metrics["recall"] >= 0 and met.metrics["recall"] <= 1
- assert met.metrics["precision"] >= 0 and met.metrics["precision"] <= 1
+ assert met["shd"] == THREE
+ assert met["gscore"] >= 0 and met["gscore"] <= 1
+ assert met["f1"] >= 0 and met["f1"] <= 1
+ assert met["recall"] >= 0 and met["recall"] <= 1
+ assert met["precision"] >= 0 and met["precision"] <= 1
diff --git a/tests/test_models_dag.py b/tests/test_models_dag.py
index 9d3c7c2..73bc020 100644
--- a/tests/test_models_dag.py
+++ b/tests/test_models_dag.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -12,6 +13,7 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see .
"""
+
import math
import os
import pickle
@@ -20,29 +22,46 @@
import numpy as np
import pandas as pd
import pytest
-from sympy.stats import Beta
from causalAssembly.models_dag import NodeAttributes, ProcessCell, ProductionLineGraph
class TestProcessCell:
+ """Test process Cell.
+
+ Returns:
+ _type_: _description_
+ """
+
@pytest.fixture(scope="class")
def cell(self):
+ """Set up a ProcessCell instance.
+
+ Returns:
+ ProcessCell: cell named "PYTEST".
+ """
c = ProcessCell(name="PYTEST")
return c
@pytest.fixture(scope="class")
def module(self):
+ """Set up a module for testing.
+
+ Returns:
+ nx.DiGraph: three-node module with edges A -> B and B -> C.
+ """
m = nx.DiGraph()
m.add_nodes_from(["A", "B", "C"])
m.add_edges_from([("A", "B"), ("B", "C")])
return m
def test_instance_is_created(self):
+ """Test whether an instance of ProcessCell can be created with a name."""
cell = ProcessCell(name="PYTEST")
assert isinstance(cell, ProcessCell)
def test_next_module_prefix_works(self):
+ """Test whether the next module prefix is generated correctly."""
# Arrange
cell = ProcessCell(name="PYTEST")
@@ -55,10 +74,18 @@ def test_next_module_prefix_works(self):
assert cell.next_module_prefix() == "ABC1"
def test_module_prefix_setter_works(self, cell):
+ """TEst that the module prefix can be set correctly.
+
+ Args:
+ cell (_type_): _description_
+ """
with pytest.raises(ValueError):
cell.module_prefix = 1
def test_add_module_works(self, module):
+ """Test that a module can be added to the ProcessCell."""
+ TWO = 2
+ SIX = 6
# Arrange
cell = ProcessCell(name="C01")
cell.module_prefix = "M"
@@ -68,11 +95,16 @@ def test_add_module_works(self, module):
cell.add_module(module)
# Assert
- assert len(cell.modules) == 2
- assert len(cell.graph.nodes()) == 6
+ assert len(cell.modules) == TWO
+ assert len(cell.graph.nodes()) == SIX
assert cell.next_module_prefix() == "M3"
def test_connect_by_module_works(self, module):
+ """Test that modules can be connected by edges.
+
+ Args:
+ module (nx.DiGraph): module fixture.
+ """
# Arrange
cell = ProcessCell(name="PyTestCell")
m1 = cell.add_module(graph=module)
@@ -92,6 +124,12 @@ def test_connect_by_module_works(self, module):
ids=["source node invalid", "target node invalid"],
)
def test_connect_by_module_fails_with_wrong_node_name(self, module, edges):
+ """Test that module connections fails when node names are invalid.
+
+ Args:
+ module (_type_): _description_
+ edges (_type_): _description_
+ """
# Arrange
cell = ProcessCell(name="PyTestCell")
@@ -104,6 +142,11 @@ def test_connect_by_module_fails_with_wrong_node_name(self, module, edges):
cell.connect_by_module(m1=m1, m2=m2, edges=edges)
def test_node_property(self, module):
+ """Test properties of the nodes.
+
+ Args:
+ module (nx.DiGraph): module fixture.
+ """
# Arrange
cell = ProcessCell(name="PyTest")
@@ -116,6 +159,11 @@ def test_node_property(self, module):
assert len(cell.nodes) == expected_no_of_nodes
def test_repr_is_working(self, module):
+ """Test repr is working.
+
+ Args:
+ module (nx.DiGraph): module fixture.
+ """
# Arrange
cell = ProcessCell(name="PyTest")
@@ -131,6 +179,11 @@ def test_repr_is_working(self, module):
ids=["sparsity=0.0", "sparsity=0.1", "sparsity=1.0"],
)
def test_connect_by_random_edges(self, sparsity):
+ """Test whether connecting with random edges works.
+
+ Args:
+ sparsity (float): parametrized sparsity level.
+ """
pline = ProductionLineGraph()
pline.new_cell(name="C1")
# Arrange
@@ -149,6 +202,7 @@ def test_connect_by_random_edges(self, sparsity):
assert len(c.graph.edges) == expected_edges
def test_connect_by_random_edges_fails_with_cyclic_graph(self):
+ """Test failure with cyclic graphs."""
pline = ProductionLineGraph()
pline.new_cell(name="C1")
# Arrange
@@ -166,6 +220,11 @@ def test_connect_by_random_edges_fails_with_cyclic_graph(self):
c.connect_by_random_edges()
def test_get_nodes_by_attribute(self, module):
+ """Test get nodes by attribute.
+
+ Args:
+ module (nx.DiGraph): module fixture.
+ """
c = ProcessCell(name="C1")
c.add_module(graph=module)
@@ -174,15 +233,18 @@ def test_get_nodes_by_attribute(self, module):
assert NodeAttributes.ALLOW_IN_EDGES in available_attributes
def test_input_cellgraph_directly_works(self):
+ """Test whether cellgraph is inputted correctly."""
+ THREE = 3
toygraph = nx.DiGraph()
toygraph.add_edges_from([("a", "b"), ("a", "c"), ("b", "c")])
c = ProcessCell(name="toycell")
c.input_cellgraph_directly(toygraph)
- assert len(c.nodes) == 3
+ assert len(c.nodes) == THREE
assert c.nodes == ["toycell_a", "toycell_b", "toycell_c"]
def test_ground_truth_cell(self):
+ """Test ground truth."""
pline = ProductionLineGraph()
pline.new_cell(name="test")
pline.test.add_random_module()
@@ -195,11 +257,15 @@ def test_ground_truth_cell(self):
class TestProductionLineGraph:
+ """Test ProductionLineGraph."""
+
def test_instance_is_created(self):
+ """Test whether instance is created."""
p = ProductionLineGraph()
assert isinstance(p, ProductionLineGraph)
def test_getattr_works(self):
+ """Test getattr."""
# Arrange
station_name = "Station1"
p = ProductionLineGraph()
@@ -211,6 +277,7 @@ def test_getattr_works(self):
p.XXX
def test_str_representation(self):
+ """Test str."""
p = ProductionLineGraph()
p.new_cell(name="C1")
p.C1.add_random_module(n_nodes=10)
@@ -218,6 +285,7 @@ def test_str_representation(self):
assert isinstance(str(p), str)
def test_create_cell_works(self):
+ """Test cell creation."""
p = ProductionLineGraph()
c1 = p.new_cell()
@@ -225,12 +293,14 @@ def test_create_cell_works(self):
assert c1.name == "C0"
def test_create_cell_with_name_works(self):
+ """Test cell with name."""
p = ProductionLineGraph()
c1 = p.new_cell(name="PyTest")
assert c1.name == "PyTest"
def test_append_same_cell_twice_fails(self):
+ """Test failure."""
p = ProductionLineGraph()
p.new_cell(name="PyTest")
@@ -238,18 +308,21 @@ def test_append_same_cell_twice_fails(self):
p.new_cell(name="PyTest")
def test_instance_via_cell_number_works(self):
+ """Test instance via cell number."""
n_cells = 10
p = ProductionLineGraph.via_cell_number(n_cells=n_cells)
assert len(p.cells) == n_cells
def test_if_graph_exists(self):
+ """Test existence."""
n_cells = 10
p = ProductionLineGraph.via_cell_number(n_cells=n_cells)
assert isinstance(p.graph, nx.DiGraph)
def test_add_eol_cell(self):
+ """Test eol cell."""
p = ProductionLineGraph()
p.new_cell()
p.new_cell(is_eol=True)
@@ -257,6 +330,7 @@ def test_add_eol_cell(self):
assert isinstance(p.eol_cell, ProcessCell)
def test_add_eol_cell_twice_fails(self):
+ """Test eol twice fails."""
p = ProductionLineGraph()
p.new_cell(is_eol=True)
@@ -269,6 +343,7 @@ def test_add_eol_cell_twice_fails(self):
ids=["some edges", "zero edges", "all edges"],
)
def test_connect_cells_works_with_single_cell(self, n_nodes, forward_prob):
+ """Test connect."""
# Arrange
p = ProductionLineGraph()
p.new_cell(name="C1")
@@ -305,6 +380,12 @@ def test_connect_cells_works_with_single_cell(self, n_nodes, forward_prob):
],
)
def test_connect_cells_works_with_multiple_cells(self, n_nodes, forward_probs):
+ """Test connect with multiple cells.
+
+ Args:
+ n_nodes (int): parametrized number of nodes per cell.
+ forward_probs (list[float]): parametrized forward-edge probabilities.
+ """
# Arrange
p = ProductionLineGraph()
@@ -335,6 +416,12 @@ def test_connect_cells_works_with_multiple_cells(self, n_nodes, forward_probs):
ids=["some edges", "zero edges", "all edges"],
)
def test_connect_cells_works_with_eol_cell(self, n_nodes, forward_prob):
+ """Test connect with eol.
+
+ Args:
+ n_nodes (int): parametrized number of nodes per cell.
+ forward_prob (float): parametrized forward-edge probability.
+ """
# Arrange
p = ProductionLineGraph()
@@ -357,6 +444,7 @@ def test_connect_cells_works_with_eol_cell(self, n_nodes, forward_prob):
assert len(p.graph.edges) == no_of_expected_edges
def test_connect_across_cells_manually(self):
+ """TEst manual connection."""
n_nodes = 10
prob = 0.1
p = ProductionLineGraph()
@@ -376,6 +464,7 @@ def test_connect_across_cells_manually(self):
assert len(p.graph.edges) == edges_in_C1 + edges_in_C2 + len(edgelist)
def test_acyclicity_error(self):
+ """Test acycicity."""
n_nodes = 5
prob = 0.1
p = ProductionLineGraph()
@@ -396,6 +485,7 @@ def test_acyclicity_error(self):
print(p.graph.edges)
def test_ground_truth_visible(self):
+ """Test ground truth visible."""
n_nodes = 5
prob = 0.1
p = ProductionLineGraph()
@@ -414,6 +504,8 @@ def test_ground_truth_visible(self):
assert len(p._pairs_with_hidden_mediators()) == 0
def test_ground_truth_hidden(self):
+ """TEst hidden gt."""
+ TWO = 2
edges1 = [(1, 2), (2, 3)]
edges2 = [(1, 3), (2, 3)]
edges3 = [(1, 2), (2, 3)]
@@ -444,10 +536,13 @@ def test_ground_truth_hidden(self):
# {(C2_M1_1, C2_M1_2): C1_M1_3}
assert p._pairs_with_hidden_confounders() == {("C2_M1_1", "C2_M1_2"): ["C1_M1_3"]}
- assert p.ground_truth_visible.loc[("C2_M1_1", "C2_M1_2")] == 2
- assert p.ground_truth_visible.loc[("C2_M1_2", "C2_M1_1")] == 2
+ assert p.ground_truth_visible.loc[("C2_M1_1", "C2_M1_2")] == TWO
+ assert p.ground_truth_visible.loc[("C2_M1_2", "C2_M1_1")] == TWO
def test_input_cellgraph_directly(self):
+ """Test input directly."""
+ SIX = 6
+ FOUR = 4
dag1 = nx.DiGraph([(0, 1), (1, 2)])
dag2 = nx.DiGraph([(3, 4), (3, 5)])
@@ -457,20 +552,24 @@ def test_input_cellgraph_directly(self):
testline.new_cell(name="Station2")
testline.Station2.input_cellgraph_directly(graph=dag2)
- assert testline.num_nodes == 6
- assert testline.num_edges == 4
+ assert testline.num_nodes == SIX
+ assert testline.num_edges == FOUR
assert testline.sparsity == pytest.approx(4 / math.comb(6, 2))
def test_drf_size(self):
+ """TEst drf size."""
testline = ProductionLineGraph()
assert not testline.drf
def test_drf_error(self):
+ """TEst drf error."""
testline = ProductionLineGraph()
with pytest.raises(ValueError):
testline.sample_from_drf()
def test_from_nx(self):
+ """Test from nx."""
+ TWO = 2
nx_graph = nx.DiGraph(
[("1", "2"), ("1", "3"), ("1", "4"), ("2", "5"), ("2", "6"), ("5", "6")]
)
@@ -478,7 +577,7 @@ def test_from_nx(self):
cell_mapper = {"cell1": ["1", "2", "3", "4"], "cell2": ["5", "6"]}
pline_from_nx = ProductionLineGraph.from_nx(g=nx_graph, cell_mapper=cell_mapper)
- assert len(pline_from_nx.cells) == 2
+ assert len(pline_from_nx.cells) == TWO
assert set(pline_from_nx.cell1.nodes) == {
"cell1_1",
"cell1_2",
@@ -490,18 +589,30 @@ def test_from_nx(self):
ProductionLineGraph.from_nx(pd_graph, cell_mapper=cell_mapper)
def test_save_and_load_drf(self, tmp_path_factory):
+ """Test save and load.
+
+ Args:
+ tmp_path_factory (pytest.TempPathFactory): pytest fixture for temporary directories.
+ """
basedir = tmp_path_factory.mktemp("data")
filename = "drf.pkl"
line1 = ProductionLineGraph()
- line1.drf = np.array([[1, 2, 3]])
+ line1.drf = np.array([[1, 2, 3]]) # type: ignore
line1.save_drf(filename=filename, location=basedir)
line2 = ProductionLineGraph()
line2.drf = ProductionLineGraph.load_drf(filename=filename, location=basedir)
- assert np.array_equal(line2.drf, np.array([[1, 2, 3]]))
+ assert np.array_equal(line2.drf, np.array([[1, 2, 3]])) # type: ignore
def test_pickleability(self, tmp_path):
+ """Test pickle.
+
+ Args:
+ tmp_path (pathlib.Path): pytest fixture providing a temporary directory.
+ """
+ SEVEN = 7
+ TEN = 10
# Arrange
filename_path = os.path.join(tmp_path, "pline.pkl")
pline = ProductionLineGraph()
@@ -523,10 +634,11 @@ def test_pickleability(self, tmp_path):
# Assert
assert new_edge in pline_reloaded.edges
- assert len(pline.Station1.nodes) == 7
- assert len(pline.Station2.nodes) == 10
+ assert len(pline.Station1.nodes) == SEVEN
+ assert len(pline.Station2.nodes) == TEN
def test_copy(self):
+ """Test copy."""
pline = ProductionLineGraph()
pline.new_cell(name="Station1")
pline.new_cell(name="Station2")
@@ -543,6 +655,7 @@ def test_copy(self):
assert pline.cell_order == copyline.cell_order
def test_within_edges_with_empty_cells_raises_error(self):
+ """Test error."""
# Setup
pline = ProductionLineGraph()
@@ -551,6 +664,7 @@ def test_within_edges_with_empty_cells_raises_error(self):
pline.within_adjacency
def test_between_edges_with_empty_cells_raises_error(self):
+ """Test error."""
# Setup
pline = ProductionLineGraph()
@@ -559,6 +673,7 @@ def test_between_edges_with_empty_cells_raises_error(self):
print(pline.between_adjacency)
def test_within_edges_adjacency_matrix(self):
+ """Test within amat."""
# Setup
nx_graph = nx.DiGraph(
[("1", "2"), ("1", "3"), ("1", "4"), ("2", "5"), ("2", "6"), ("5", "6")]
@@ -583,9 +698,11 @@ def test_within_edges_adjacency_matrix(self):
assert (
within_amat.loc["cell1_2", :].sum() == 0
and pline.ground_truth.loc["cell1_2", :].sum() != 0
- )
+ ) # type: ignore
def test_between_edges_adjacency_matrix(self):
+ """Test between edges amat."""
+ TWO = 2
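+ # The assertions below expect two between-cell edges out of node cell1_2.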
# Setup
nx_graph = nx.DiGraph(
[("1", "2"), ("1", "3"), ("1", "4"), ("2", "5"), ("2", "6"), ("5", "6")]
@@ -599,13 +716,14 @@ def test_between_edges_adjacency_matrix(self):
# Assert
assert (
- between_amat.loc["cell1_2", :].sum() == 2
- and pline.ground_truth.loc["cell1_2", :].sum() == 2
- )
+ between_amat.loc["cell1_2", :].sum() == TWO
+ and pline.ground_truth.loc["cell1_2", :].sum() == TWO
+ ) # type: ignore
assert between_amat.loc[pline.cell1.nodes, pline.cell1.nodes].sum().sum() == 0
assert between_amat.loc[pline.cell2.nodes, pline.cell2.nodes].sum().sum() == 0
def test_interventional_drf_error(self):
+ """Test interventional drf."""
testline = ProductionLineGraph()
with pytest.raises(ValueError):
testline.sample_from_interventional_drf()
diff --git a/tests/test_models_fcm.py b/tests/test_models_fcm.py
index 14e7292..7824d01 100644
--- a/tests/test_models_fcm.py
+++ b/tests/test_models_fcm.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -24,8 +25,19 @@
class TestFCM:
+ """Test FCM class.
+
+ Returns:
+ _type_: _description_
+ """
+
@pytest.fixture(scope="class")
def example_fcm(self):
+ """Set up example fcm.
+
+ Returns:
+ _type_: _description_
+ """
x, y, z = symbols("x,y,z")
eq_x = Eq(x, Uniform("error", left=-1, right=1))
@@ -35,12 +47,17 @@ def example_fcm(self):
eq_list = [eq_x, eq_y, eq_z]
example_fcm = FCM(name="example_fcm", seed=2023)
- example_fcm.input_fcm(eq_list)
+ example_fcm.input_fcm(eq_list) # type: ignore
return example_fcm
@pytest.fixture(scope="class")
def medium_example_fcm(self) -> FCM:
+ """Set up medium size exmaple.
+
+ Returns:
+ FCM: _description_
+ """
v, x, y, z = symbols("v,x,y,z")
eq_x = Eq(x, Normal("error", 0, 1))
@@ -51,24 +68,29 @@ def medium_example_fcm(self) -> FCM:
eq_list = [eq_v, eq_x, eq_y, eq_z]
test_fcm = FCM(name="testing", seed=2023)
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
return test_fcm
def test_instance_is_created(self):
+ """TEst instance created."""
h = FCM(name="mymodel", seed=1234)
assert isinstance(h, FCM)
def test_input_fcm_works(self, example_fcm):
+ """Test input works."""
+ THREE = 3
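+ # example_fcm has three nodes (x, y, z) and three edges.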
# Act
x, y = symbols("x,y")
# Assert
assert len(example_fcm.source_nodes) == 1
- assert example_fcm.num_nodes == 3
- assert example_fcm.num_edges == 3
+ assert example_fcm.num_nodes == THREE
+ assert example_fcm.num_edges == THREE
assert (x, y) in example_fcm.edges
def test_empty_graph_works(self):
+ """Test empty graph works."""
+ THREE = 3
# Arrange
x, y, z = symbols("x,y,z")
@@ -78,24 +100,31 @@ def test_empty_graph_works(self):
# Act
test_fcm = FCM()
- test_fcm.input_fcm([eq_x, eq_y, eq_z])
+ test_fcm.input_fcm([eq_x, eq_y, eq_z]) # type: ignore
df = test_fcm.sample(size=5)
# Assert
- assert test_fcm.num_nodes == 3
+ assert test_fcm.num_nodes == THREE
assert test_fcm.num_edges == 0
assert df.shape == (5, 3)
def test_draw_without_noise_works(self, example_fcm):
+ """Test w/o noise.
+
+ Args:
+ example_fcm (_type_): _description_
+ """
+ TEN = 10
# Act
df = example_fcm.sample(size=10, additive_gaussian_noise=False)
# Assert
- assert len(df) == 10
+ assert len(df) == TEN
def test_draw_with_noise_works(self, example_fcm):
+ """Test w/o noise."""
# Act
df_without_noise = example_fcm.sample(size=10, additive_gaussian_noise=False)
df_with_noise = example_fcm.sample(size=10, additive_gaussian_noise=True)
@@ -106,6 +135,11 @@ def test_draw_with_noise_works(self, example_fcm):
assert not np.allclose(df_without_noise[col], df_with_noise[col])
def test_draw_from_dataframe(self, example_fcm: FCM):
+ """Test draw from df.
+
+ Args:
+ example_fcm (FCM): _description_
+ """
# Arrange
source_df = pd.DataFrame()
source_df["x"] = [0, 1.0, 10.0]
@@ -125,6 +159,7 @@ def test_draw_from_dataframe(self, example_fcm: FCM):
)
def test_draw_from_wrong_dataframe_raises_assertionerror(self, example_fcm):
+ """Test draw error."""
source_df = pd.DataFrame()
source_df["AAA"] = [0, 1.0, 10.0]
@@ -133,6 +168,8 @@ def test_draw_from_wrong_dataframe_raises_assertionerror(self, example_fcm):
example_fcm.sample(size=3, source_df=source_df, additive_gaussian_noise=False)
def test_specify_individual_noise(self):
+ """Test individual noise."""
+ THREE = 3
# Arrange
x, y, z = symbols("x,y,z")
@@ -143,7 +180,7 @@ def test_specify_individual_noise(self):
eq_list = [eq_x, eq_y, eq_z]
test_fcm = FCM(name="testing", seed=2023)
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
# Act
@@ -151,10 +188,11 @@ def test_specify_individual_noise(self):
# Assert
- assert test_fcm.num_edges == 3
+ assert test_fcm.num_edges == THREE
assert df.shape == (10, 3)
def test_order_in_eval_always_correct(self):
+ """Test order is correct when evaluating."""
# Arrange
v, w, x, y, z = symbols("v,w,x,y,z")
eq_v = Eq(v, Uniform("noise", left=0.2, right=0.8))
@@ -167,7 +205,7 @@ def test_order_in_eval_always_correct(self):
# Act
test_fcm = FCM()
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
df = test_fcm.sample(size=10)
# Assert
assert all(np.isclose(df["x"], 27 - df["v"]))
@@ -175,6 +213,8 @@ def test_order_in_eval_always_correct(self):
assert all(np.isclose(df["z"], (df["x"] + df["y"]) / df["v"]))
def test_select_scale_parameter_via_snr(self):
+ """Test scale param."""
+ THREE = 3
# Arrange
x, y, z = symbols("x,y,z")
sigma = Symbol("sigma", positive=True)
@@ -186,7 +226,7 @@ def test_select_scale_parameter_via_snr(self):
eq_list = [eq_x, eq_y, eq_z]
test_fcm = FCM(name="testing", seed=2023)
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
# Act
@@ -194,10 +234,11 @@ def test_select_scale_parameter_via_snr(self):
# Assert
- assert test_fcm.num_edges == 3
+ assert test_fcm.num_edges == THREE
assert df.shape == (10, 3)
def test_select_scale_parameter_via_snr_gives_error_when_not_additive(self):
+ """TEst sclae param set via snr."""
# Arrange
x, y, z = symbols("x,y,z")
sigma = Symbol("sigma", positive=True)
@@ -209,13 +250,19 @@ def test_select_scale_parameter_via_snr_gives_error_when_not_additive(self):
eq_list = [eq_x, eq_y, eq_z]
test_fcm = FCM(name="testing", seed=2023)
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
# Assert
with pytest.raises(ValueError):
test_fcm.sample(size=10, snr=0.6)
def test_data_frame_with_fewer_columns_than_source_nodes(self, medium_example_fcm: FCM):
+ """Test df with few cols.
+
+ Args:
+ medium_example_fcm (FCM): _description_
+ """
+ FOUR = 4
# Act
source_df = pd.DataFrame(
{
@@ -226,10 +273,15 @@ def test_data_frame_with_fewer_columns_than_source_nodes(self, medium_example_fc
df = medium_example_fcm.sample(size=10, snr=0.6, source_df=source_df)
# Assert
- assert medium_example_fcm.num_edges == 4
+ assert medium_example_fcm.num_edges == FOUR
assert df.shape == (10, 4)
def test_data_frame_too_few_rows(self, medium_example_fcm: FCM):
+ """Test df too few rows.
+
+ Args:
+ medium_example_fcm (FCM): _description_
+ """
# Act
source_df = pd.DataFrame(
{
@@ -243,6 +295,7 @@ def test_data_frame_too_few_rows(self, medium_example_fcm: FCM):
medium_example_fcm.sample(size=10, snr=0.6, source_df=source_df)
def test_polynomial_equation_works(self):
+ """Test polynomial eq."""
# Arrange
x, y = symbols("x,y")
@@ -250,12 +303,14 @@ def test_polynomial_equation_works(self):
eq_y = Eq(y, x**2 - 2 * x + 5)
test_fcm = FCM()
- test_fcm.input_fcm([eq_x, eq_y])
+ test_fcm.input_fcm([eq_x, eq_y]) # type: ignore
# Act
df = test_fcm.sample(size=5)
assert all(df["y"] == df["x"] ** 2 - 2 * df["x"] + 5)
def test_display_functions_works(self):
+ """Test display."""
+ FOUR = 4
# Arrange
v, x, y, z = symbols("v,x,y,z")
@@ -267,17 +322,18 @@ def test_display_functions_works(self):
eq_list = [eq_v, eq_x, eq_y, eq_z]
test_fcm = FCM(name="testing", seed=2023)
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
# Act
functions_dict = test_fcm.display_functions()
# Assert
assert isinstance(functions_dict, dict)
- assert len(functions_dict) == 4
+ assert len(functions_dict) == FOUR
assert str(functions_dict[x]) == str(functions_dict[v]) == "error"
assert functions_dict[y] == eq_y.rhs
assert functions_dict[z] == eq_z.rhs
def test_show_individual_function(self):
+ """Test indiv. functions."""
# Arrange
v, x, y, z = symbols("v,x,y,z")
@@ -289,16 +345,17 @@ def test_show_individual_function(self):
eq_list = [eq_v, eq_x, eq_y, eq_z]
test_fcm = FCM(name="testing", seed=2023)
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
# Assert
assert isinstance(test_fcm.function_of(node=x), dict)
with pytest.raises(AssertionError):
- test_fcm.function_of(node="x")
+ test_fcm.function_of(node="x") # type: ignore
with pytest.raises(AssertionError):
- test_fcm.function_of(node="m")
+ test_fcm.function_of(node="m") # type: ignore
assert test_fcm.function_of(node=y) == {y: eq_y.rhs}
def test_single_hard_intervention(self):
+ """Test hard inter."""
# Arrange
x, y, z = symbols("x,y,z")
@@ -309,7 +366,7 @@ def test_single_hard_intervention(self):
eq_list = [eq_x, eq_y, eq_z]
test_fcm = FCM()
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
# Act
test_fcm.intervene_on({y: 2.5})
@@ -324,7 +381,9 @@ def test_single_hard_intervention(self):
assert np.isclose(test_df["y"].mean(), 2.5)
def test_single_soft_intervention(self):
+ """TEst soft interv."""
# Arrange
+ POINT_THREE = 0.3
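+ # Bounds of the Uniform(left=-0.3, right=0.3) soft intervention below.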
x, y, z = symbols("x,y,z")
eq_x = Eq(x, Normal("error", 0, 1))
@@ -334,7 +393,7 @@ def test_single_soft_intervention(self):
eq_list = [eq_x, eq_y, eq_z]
test_fcm = FCM()
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
# Act
test_fcm.intervene_on({z: Uniform("noise", left=-0.3, right=0.3)})
@@ -346,9 +405,10 @@ def test_single_soft_intervention(self):
assert test_fcm.interventions[0] == "do([z])"
assert len(test_fcm.mutilated_dags) == 1
assert len(test_fcm.mutilated_dags["do([z])"].edges()) == test_fcm.num_edges - 2
- assert -0.3 < test_df["z"].min() and test_df["z"].min() < 0.3
+ assert -POINT_THREE < test_df["z"].min() < POINT_THREE
def test_multiple_interventions_at_once(self):
+ """Test mutliple interv."""
# Arrange
v, x, y, z = symbols("v,x,y,z")
@@ -360,7 +420,7 @@ def test_multiple_interventions_at_once(self):
eq_list = [eq_v, eq_x, eq_y, eq_z]
test_fcm = FCM()
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
# Act
test_fcm.intervene_on({v: 5, z: Normal("error", 0, 1)})
@@ -374,6 +434,8 @@ def test_multiple_interventions_at_once(self):
assert np.isclose(test_df["v"].mean(), 5)
def test_multiple_interventions_sequentually(self):
+ """Test interv. sequentially."""
+ TWO = 2
# Arrange
v, x, y, z = symbols("v,x,y,z")
@@ -385,7 +447,7 @@ def test_multiple_interventions_sequentually(self):
eq_list = [eq_v, eq_x, eq_y, eq_z]
test_fcm = FCM()
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
# Act
test_fcm.intervene_on({v: 5})
@@ -395,13 +457,14 @@ def test_multiple_interventions_sequentually(self):
test_fcm.interventional_sample(size=5, which_intervention=1)
# Assert
- assert len(test_fcm.interventions) == 2
- assert len(test_fcm.mutilated_dags) == 2
+ assert len(test_fcm.interventions) == TWO
+ assert len(test_fcm.mutilated_dags) == TWO
assert len(test_fcm.mutilated_dags["do([v])"].edges()) == test_fcm.num_edges - 1
assert len(test_fcm.mutilated_dags["do([z])"].edges()) == test_fcm.num_edges - 2
assert np.isclose(test_df_1["v"].mean(), 5)
def test_reproducability(self):
+ """Test repoducability."""
# Arrange
v, x, y, z, sigma = symbols("v,x,y,z,sigma")
@@ -414,7 +477,7 @@ def test_reproducability(self):
# Act
test_fcm = FCM(name="testing", seed=2023)
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
source_df = pd.DataFrame(
{
@@ -425,7 +488,7 @@ def test_reproducability(self):
df_a = test_fcm.sample(size=5, snr=2 / 3, additive_gaussian_noise=True, source_df=source_df)
test_fcm = FCM(name="testing", seed=2023)
- test_fcm.input_fcm(eq_list)
+ test_fcm.input_fcm(eq_list) # type: ignore
df_b = test_fcm.sample(size=5, snr=2 / 3, additive_gaussian_noise=True, source_df=source_df)
diff --git a/tests/test_pdag.py b/tests/test_pdag.py
index f73d7c1..b9d288c 100644
--- a/tests/test_pdag.py
+++ b/tests/test_pdag.py
@@ -1,4 +1,5 @@
-""" Utility classes and functions related to causalAssembly.
+"""Utility classes and functions related to causalAssembly.
+
Copyright (c) 2023 Robert Bosch GmbH
This program is free software: you can redistribute it and/or modify
@@ -20,10 +21,17 @@
from causalAssembly.pdag import PDAG, dag2cpdag
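+# Named constants for expected counts, shared by the tests below.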
+TWO = 2
+THREE = 3
+FOUR = 4
+
class TestPDAG:
+ """Test PDAG class."""
+
@pytest.fixture(scope="class")
def mixed_pdag(self) -> PDAG:
+ """Set up pdag."""
pdag = PDAG(
nodes=["A", "B", "C"],
dir_edges=[("A", "B"), ("A", "C")],
@@ -32,43 +40,55 @@ def mixed_pdag(self) -> PDAG:
return pdag
def test_instance_is_created(self):
+ """Test instance."""
pdag = PDAG(nodes=["A", "B", "C"])
assert isinstance(pdag, PDAG)
def test_dir_edges(self):
+ """Test dir edges."""
pdag = PDAG(nodes=["A", "B", "C"], dir_edges=[("A", "B"), ("A", "C")])
- assert pdag.num_dir_edges == 2
+ assert pdag.num_dir_edges == TWO
assert pdag.num_undir_edges == 0
assert set(pdag.dir_edges) == {("A", "B"), ("A", "C")}
def test_undir_edges(self):
+ """Test undir edges."""
pdag = PDAG(nodes=["A", "B", "C"], undir_edges=[("A", "B"), ("A", "C")])
assert pdag.num_dir_edges == 0
- assert pdag.num_undir_edges == 2
+ assert pdag.num_undir_edges == TWO
assert set(pdag.undir_edges) == {("A", "B"), ("A", "C")}
def test_mixed_edges(self, mixed_pdag: PDAG):
- assert mixed_pdag.num_dir_edges == 2
+ """Test mixed edges."""
+ assert mixed_pdag.num_dir_edges == TWO
assert mixed_pdag.num_undir_edges == 1
assert set(mixed_pdag.dir_edges) == {("A", "B"), ("A", "C")}
assert set(mixed_pdag.undir_edges) == {("B", "C")}
def test_children(self, mixed_pdag: PDAG):
+ """Test child edges."""
assert mixed_pdag.children(node="A") == {"B", "C"}
assert mixed_pdag.children(node="B") == set()
assert mixed_pdag.children(node="C") == set()
def test_parents(self, mixed_pdag: PDAG):
+ """Test parent edges."""
assert mixed_pdag.parents(node="A") == set()
assert mixed_pdag.parents(node="B") == {"A"}
assert mixed_pdag.parents(node="C") == {"A"}
def test_neighbors(self, mixed_pdag: PDAG):
+ """Test neighbors."""
assert mixed_pdag.neighbors(node="C") == {"B", "A"}
assert mixed_pdag.undir_neighbors(node="C") == {"B"}
assert mixed_pdag.is_adjacent(i="B", j="C")
def test_from_pandas_adjacency(self, mixed_pdag: PDAG):
+ """Test import from pandas.
+
+ Args:
+ mixed_pdag (PDAG): _description_
+ """
amat = pd.DataFrame(
[[0, 1, 1], [0, 0, 1], [0, 1, 0]],
columns=["A", "B", "C"],
@@ -80,6 +100,11 @@ def test_from_pandas_adjacency(self, mixed_pdag: PDAG):
assert from_pandas_pdag.num_undir_edges == mixed_pdag.num_undir_edges
def test_remove_edge(self, mixed_pdag: PDAG):
+ """Test remove edges.
+
+ Args:
+ mixed_pdag (PDAG): _description_
+ """
assert ("A", "C") in mixed_pdag.dir_edges
mixed_pdag.remove_edge("A", "C")
assert ("A", "C") not in mixed_pdag.dir_edges
@@ -87,6 +112,11 @@ def test_remove_edge(self, mixed_pdag: PDAG):
mixed_pdag.remove_edge("B", "A")
def test_change_undir_edge_to_dir_edge(self, mixed_pdag: PDAG):
+ """Test change.
+
+ Args:
+ mixed_pdag (PDAG): _description_
+ """
assert ("B", "C") in mixed_pdag.undir_edges or (
"C",
"B",
@@ -98,25 +128,41 @@ def test_change_undir_edge_to_dir_edge(self, mixed_pdag: PDAG):
assert ("C", "B") not in mixed_pdag.undir_edges
def test_remove_node(self, mixed_pdag: PDAG):
+ """Test remove nodes.
+
+ Args:
+ mixed_pdag (PDAG): _description_
+ """
assert "C" in mixed_pdag.nodes
mixed_pdag.remove_node("C")
assert "C" not in mixed_pdag.nodes
def test_to_dag(self, mixed_pdag: PDAG):
+ """Test conversion to DAG.
+
+ Args:
+ mixed_pdag (PDAG): _description_
+ """
dag = mixed_pdag.to_dag()
assert nx.is_directed_acyclic_graph(dag)
assert set(mixed_pdag.dir_edges).issubset(set(dag.edges))
def test_adjacency_matrix(self, mixed_pdag: PDAG):
+ """Test return of adjacency matrix.
+
+ Args:
+ mixed_pdag (PDAG): _description_
+ """
amat = mixed_pdag.adjacency_matrix
assert amat.shape[0] == amat.shape[1] == mixed_pdag.nnodes
assert amat.sum().sum() == mixed_pdag.num_dir_edges + 2 * mixed_pdag.num_undir_edges
def test_dag2cpdag(self):
+ """Test conversion from DAG to CPDAG."""
dag1 = nx.DiGraph([("1", "2"), ("2", "3"), ("3", "4")])
cpdag1 = dag2cpdag(dag=dag1)
assert cpdag1.num_dir_edges == 0
- assert cpdag1.num_undir_edges == 3
+ assert cpdag1.num_undir_edges == THREE
dag2 = nx.DiGraph([("1", "3"), ("2", "3")])
cpdag2 = dag2cpdag(dag=dag2)
@@ -125,10 +171,11 @@ def test_dag2cpdag(self):
dag3 = nx.DiGraph([("1", "3"), ("2", "3"), ("1", "4")])
cpdag3 = dag2cpdag(dag=dag3)
- assert cpdag3.num_dir_edges == 2
+ assert cpdag3.num_dir_edges == TWO
assert cpdag3.num_undir_edges == 1
def test_example_a_to_allDAGs(self):
+ """Test example PDAG to allDAGs."""
# Set up CPDAG: a - c - b -> MEC has 3 Members
pdag = nx.Graph([("a", "c"), ("b", "c")])
amat = nx.to_pandas_adjacency(pdag)
@@ -139,10 +186,11 @@ def test_example_a_to_allDAGs(self):
# Act
all_dags = example_pdag.to_allDAGs()
- assert len(all_dags) == 3
+ assert len(all_dags) == THREE
assert all([isinstance(dag, nx.DiGraph) for dag in all_dags])
def test_example_b_to_allDAGs(self):
+ """Test example PDAG to allDAGs."""
# Set up CPDAG: a - (b,c,d) -> MEC has 4 Members
pdag = nx.Graph([("a", "b"), ("a", "c"), ("a", "d")])
amat = nx.to_pandas_adjacency(pdag)
@@ -153,10 +201,11 @@ def test_example_b_to_allDAGs(self):
# Act
all_dags = example_pdag.to_allDAGs()
- assert len(all_dags) == 4
+ assert len(all_dags) == FOUR
assert all([isinstance(dag, nx.DiGraph) for dag in all_dags])
def test_empty_graph_to_allDAGs(self):
+ """Test empty graph to allDAGs."""
# Set up empty PDAG, has exactly one DAG that is the same as the PDAG.
pdag = nx.Graph()
pdag.add_nodes_from(["a", "b", "c", "d"])
@@ -173,6 +222,7 @@ def test_empty_graph_to_allDAGs(self):
assert set(all_dags[0].nodes) == set(pdag.nodes)
def test_to_random_dag(self):
+ """Test to random DAG."""
# Set up CPDAG: a - (b,c,d) -> MEC has 4 Members
pdag = nx.Graph([("a", "b"), ("a", "c"), ("a", "d")])
amat = nx.to_pandas_adjacency(pdag)