Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ jobs:
run: |
sudo apt-get install -y libcurl4-openssl-dev libssl-dev
Rscript -e 'install.packages(c("devtools", "rextendr"), repos="https://cloud.r-project.org")'

- name: Add Rust Windows target
if: runner.os == 'Windows'
run: rustup target add x86_64-pc-windows-gnu

# Build all bindings (Rust, Python, WASM, R)
- name: Build all components
Expand Down
8 changes: 4 additions & 4 deletions python_bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,15 +306,15 @@ impl PyRustDAG {
.map_err(PyValueError::new_err)
}

#[pyo3(signature = (start, end, include_latents=false))]
#[pyo3(signature = (starts, ends, include_latents=false))]
pub fn minimal_dseparator(
&self,
start: String,
end: String,
starts: Vec<String>,
ends: Vec<String>,
include_latents: bool,
) -> PyResult<Option<std::collections::HashSet<String>>> {
self.inner
.minimal_dseparator(&start, &end, include_latents)
.minimal_dseparator(starts.clone(), ends.clone(), include_latents)
.map_err(PyValueError::new_err)
}

Expand Down
24 changes: 12 additions & 12 deletions python_bindings/tests/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,45 +21,45 @@ def test_parents_children(self, dag):
assert dag.get_children("X") == ["Y"]

def test_minimal_dseparator(self):
# Test case: A → B C
# Test case: A → B C
dag1 = DAG()
dag1.add_edges_from([("A", "B"), ("B", "C")])
assert dag1.minimal_dseparator("A", "C") == {"B"}
assert dag1.minimal_dseparator(["A"], ["C"]) == {"B"}

# Test case: A → B C, B → D, A → E, E → D
# Test case: A → B C, C → D, A → E, E → D
dag2 = DAG()
dag2.add_edges_from([("A", "B"), ("B", "C"), ("C", "D"), ("A", "E"), ("E", "D")])
assert dag2.minimal_dseparator("A", "D") == {"C", "E"}
assert dag2.minimal_dseparator(["A"], ["D"]) == {"C", "E"}

# Test case: B → A, B → C, A → D, D → C, A → E, C → E
dag3 = DAG()
dag3.add_edges_from([("B", "A"), ("B", "C"), ("A", "D"), ("D", "C"), ("A", "E"), ("C", "E")])
assert dag3.minimal_dseparator("A", "C") == {"B", "D"}
assert dag3.minimal_dseparator(["A"], ["C"]) == {"B", "D"}

# Test with latents
dag_lat1 = DAG()
dag_lat1.add_nodes_from(["A", "B", "C"], latent=[False, True, False])
dag_lat1.add_edges_from([("A", "B"), ("B", "C")])
assert dag_lat1.minimal_dseparator("A", "C") is None
# assert dag_lat1.minimal_dseparator("A", "C", include_latents=True) == {"B"}
assert dag_lat1.minimal_dseparator(["A"], ["C"]) is None
# assert dag_lat1.minimal_dseparator(["A"], ["C"], include_latents=True) == {"B"}

dag_lat2 = DAG()
dag_lat2.add_nodes_from(["A", "B", "C", "D"], latent=[False, True, False, False])
dag_lat2.add_edges_from([("A", "D"), ("D", "B"), ("B", "C")])
assert dag_lat2.minimal_dseparator("A", "C") == {"D"}
assert dag_lat2.minimal_dseparator(["A"], ["C"]) == {"D"}

dag_lat3 = DAG()
dag_lat3.add_nodes_from(["A", "B", "C", "D"], latent=[False, True, False, False])
dag_lat3.add_edges_from([("A", "B"), ("B", "D"), ("D", "C")])
assert dag_lat3.minimal_dseparator("A", "C") == {"D"}
assert dag_lat3.minimal_dseparator(["A"], ["C"]) == {"D"}

dag_lat4 = DAG()
dag_lat4.add_nodes_from(["A", "B", "C", "D"], latent=[False, False, False, True])
dag_lat4.add_edges_from([("A", "B"), ("B", "C"), ("A", "D"), ("D", "C")])
assert dag_lat4.minimal_dseparator("A", "C") is None
assert dag_lat4.minimal_dseparator(["A"], ["C"]) is None

# Test adjacent nodes (should raise error)
dag5 = DAG()
dag5.add_edges_from([("A", "B")])
with pytest.raises(ValueError, match="No possible separators because start and end are adjacent"):
dag5.minimal_dseparator("A", "B")
with pytest.raises(ValueError, match="No possible separators because A and B are adjacent"):
dag5.minimal_dseparator(["A"], ["B"])
2 changes: 1 addition & 1 deletion r_bindings/causalgraphs/R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ RDAG$are_neighbors <- function(start, end) .Call(wrap__RDAG__are_neighbors, self

RDAG$get_ancestral_graph <- function(nodes) .Call(wrap__RDAG__get_ancestral_graph, self, nodes)

RDAG$minimal_dseparator <- function(start, end, include_latents) .Call(wrap__RDAG__minimal_dseparator, self, start, end, include_latents)
RDAG$minimal_dseparator <- function(starts, ends, include_latents) .Call(wrap__RDAG__minimal_dseparator, self, starts, ends, include_latents)

#' @export
`$.RDAG` <- function (self, name) { func <- RDAG[[name]]; environment(func) <- environment(); func }
Expand Down
4 changes: 2 additions & 2 deletions r_bindings/causalgraphs/src/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ name = 'rcausalgraphs'

[dependencies]

rust_core = { git = "https://github.com/pgmpy/causalgraphs.git", branch = "main", package = "rust_core" }
# rust_core = { git = "https://github.com/pgmpy/causalgraphs.git", branch = "main", package = "rust_core" }

# For local development, comment out the Git line above and uncomment this:
# rust_core = { path = "../../../../rust_core" }
rust_core = { path = "../../../../rust_core" }

extendr-api = '*'
6 changes: 4 additions & 2 deletions r_bindings/causalgraphs/src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,10 @@ impl RDAG {
/// @param end Ending node
/// @param include_latents Whether to include latents (default: FALSE)
/// @export
fn minimal_dseparator(&self, start: String, end: String, include_latents: Option<bool>) -> extendr_api::Result<Nullable<Strings>> {
let result = self.inner.minimal_dseparator(&start, &end, include_latents.unwrap_or(false))
fn minimal_dseparator(&self, starts: Strings, ends: Strings, include_latents: Option<bool>) -> extendr_api::Result<Nullable<Strings>> {
let starts_vec: Vec<String> = starts.iter().map(|s| s.to_string()).collect();
let ends_vec: Vec<String> = ends.iter().map(|s| s.to_string()).collect();
let result = self.inner.minimal_dseparator(starts_vec.clone(), ends_vec.clone(), include_latents.unwrap_or(false))
.map_err(|e| Error::Other(e.to_string()))?;
match result {
Some(set) => {
Expand Down
81 changes: 53 additions & 28 deletions rust_core/src/dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,30 +465,44 @@ impl RustDAG {
/// Tian, Paz, Pearl (1998), *Finding Minimal d-Separators*.
pub fn minimal_dseparator(
&self,
start: &str,
end: &str,
starts: Vec<String>,
ends: Vec<String>,
include_latents: bool,
) -> Result<Option<HashSet<String>>, String> {
// Example: For DAG A→B←C, B→D, trying to separate A and C
// Adjacent nodes can't be separated by any conditioning set
if self.has_edge(start, end) || self.has_edge(end, start) {
return Err("No possible separators because start and end are adjacent".to_string());
// Validate inputs
if starts.is_empty() || ends.is_empty() {
return Ok(Some(HashSet::new()));
}

// Create ancestral graph containing only ancestors of start and end
// Example: For separating A and D in A→B←C, B→D, ancestral graph = {A, B, C, D}
let ancestral_graph = self.get_ancestral_graph(vec![start.to_string(), end.to_string()])?;
// Check for adjacent pairs - if any start-end pair is adjacent, no separator exists
for start in &starts {
for end in &ends {
if self.has_edge(start, end) || self.has_edge(end, start) {
return Err(format!(
"No possible separators because {} and {} are adjacent",
start, end
));
}
}
}

// Initial separator: all parents of both nodes (theoretical upper bound)
// Example: parents(A)={} ∪ parents(D)={B} → separator = {B}
let mut separator: HashSet<String> = self
.get_parents(start)?
.into_iter()
.chain(self.get_parents(end)?.into_iter())
.collect();

// Create ancestral graph containing only ancestors of all starts and ends
let mut all_nodes = starts.clone();
all_nodes.extend(ends.clone());
let ancestral_graph = self.get_ancestral_graph(all_nodes)?;

// Initial separator: all parents of all start and end nodes
let mut separator: HashSet<String> = HashSet::new();

for start in &starts {
separator.extend(self.get_parents(start)?);
}
for end in &ends {
separator.extend(self.get_parents(end)?);
}

// Replace latent variables with their observable parents
// Example: If B were latent with parent L, replace B with L in separator
if !include_latents {
let mut changed = true;
while changed {
Expand All @@ -507,21 +521,32 @@ impl RustDAG {
}
}

separator.remove(start);
separator.remove(end);
// Remove starts and ends from separator (can't separate a node from itself)
for start in &starts {
separator.remove(start);
}
for end in &ends {
separator.remove(end);
}

// Helper function to check if all start-end pairs are d-separated
let check_all_separated = |sep: &[String]| -> Result<bool, String> {
for start in &starts {
for end in &ends {
if ancestral_graph.is_dconnected(start, end, Some(sep.to_vec()), include_latents)? {
return Ok(false); // Found a connected pair
}
}
}
Ok(true) // All pairs are separated
};

// Sanity check: if our "guaranteed" separator doesn't work, no separator exists
if ancestral_graph.is_dconnected(
start,
end,
Some(separator.iter().cloned().collect()),
include_latents,
)? {
if !check_all_separated(&separator.iter().cloned().collect::<Vec<_>>())? {
return Ok(None);
}

// Greedy minimization: remove each node if separation still holds without it
// Example: If separator = {B, C} but {B} alone separates A from D, remove C
let mut minimal_separator = separator.clone();
for u in separator {
let test_separator: Vec<String> = minimal_separator
Expand All @@ -530,8 +555,8 @@ impl RustDAG {
.filter(|x| x != &u)
.collect();

// If still d-separated WITHOUT this node, we can remove it
if !ancestral_graph.is_dconnected(start, end, Some(test_separator), include_latents)? {
// If all pairs are still d-separated WITHOUT this node, we can remove it
if check_all_separated(&test_separator)? {
minimal_separator.remove(&u);
}
}
Expand Down
Loading
Loading