Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion contextualized/baselines/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def predict_w(self, model, dataloader, project_to_dag=True):
preds = self.predict(model, dataloader)
W = model.W.detach() * model.diag_mask
if project_to_dag:
W = torch.tensor(project_to_dag_torch(W.numpy(force=True))[0])
W = torch.tensor(project_to_dag_torch(W.numpy(force=True)))
W_batch = W.unsqueeze(0).expand(len(preds), -1, -1)
return W_batch.numpy()

Expand Down
52 changes: 41 additions & 11 deletions contextualized/dags/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,45 @@ def _simulate_single_equation(X, w, scale):
X[:, j] = _simulate_single_equation(X[:, parents], W[parents, j], scale_vec[j])
return X

def is_dag(W):
G = ig.Graph.Weighted_Adjacency(W.tolist())
return G.is_dag()


def trim_params(w, thresh=0.2):
return w * (np.abs(w) > thresh)


def project_to_dag_torch(w, thresh=0.0):
"""
Project a weight matrix to the closest DAG in Frobenius norm.
"""

if is_dag(w):
return w

w_dag = w.copy()
# Easy case first: remove diagnoal entries.
w_dag *= 1 - np.eye(w.shape[0])

# First, remove edges with weights smaller than the thresh.
w_dag = trim_params(w_dag, thresh)

# Sort nodes by magnitude of edges pointing out.
order = np.argsort(np.abs(w_dag).sum(axis=1))[::-1]

# Re-order
w_dag = w_dag[order, :][:, order]

# Keep only forward edges (i.e. upper triangular part).
w_dag = np.triu(w_dag)

# Return to original order
w_dag = w_dag[np.argsort(order), :][:, np.argsort(order)]

assert is_dag(w_dag)
return w_dag


def break_symmetry(w):
for i in range(w.shape[0]):
Expand All @@ -101,7 +140,7 @@ def break_symmetry(w):


# w is the weighted adjacency matrix
def project_to_dag_torch(w):
def project_to_dag_search(w):
if is_dag(w):
return w, 0.0

Expand Down Expand Up @@ -152,13 +191,4 @@ def binary_search(arr, low, high, w): # low and high are indices
w_dag[i][j] = 0.0

assert is_dag(w_dag)
return w_dag, thresh


def is_dag(W):
G = ig.Graph.Weighted_Adjacency(W.tolist())
return G.is_dag()


def trim_params(w, thresh=0.2):
return w * (np.abs(w) > thresh)
return w_dag
2 changes: 1 addition & 1 deletion contextualized/dags/lightning_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def _format_params(self, w_preds, **kwargs):
w_preds = self._project_factor_graph_to_var(w_preds)
if kwargs.get("project_to_dag", False):
try:
w_preds = np.array([project_to_dag_torch(w)[0] for w in w_preds])
w_preds = np.array([project_to_dag_torch(w) for w in w_preds])
except:
print("Error, couldn't project to dag. Returning normal predictions.")
return trim_params(w_preds, thresh=kwargs.get("threshold", 0.0))
Expand Down
44 changes: 44 additions & 0 deletions contextualized/dags/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,50 @@
from contextualized.dags.trainers import GraphTrainer
from contextualized.dags.losses import mse_loss as mse

class TestProjectToDag(unittest.TestCase):
"""
Test that the project_to_dag function works to create a DAG from a directed cyclic graph.
"""
def __init__(self, *args, **kwargs):
super(TestProjectToDag, self).__init__(*args, **kwargs)

def setUp(self):
"""
Shared unit test setup code.
"""
pass

def test_project_to_dag(self):
"""
Test that the project_to_dag function works to create a DAG from a directed cyclic graph.
"""
# Create a cyclic graph.
W = np.zeros((5, 5))
W[0, 1] = 1
W[1, 2] = 1
W[2, 3] = 1
W[3, 4] = 1
W[4, 0] = 1

# Project to a DAG.
dag = graph_utils.project_to_dag_torch(W)
assert graph_utils.is_dag(dag)

def test_project_to_dag_from_dag(self):
"""
Test that the project_to_dag function works to create a DAG from a DAG.
"""
# Create a DAG.
W = np.zeros((5, 5))
W[0, 1] = 1
W[1, 2] = 1
W[2, 3] = 1
W[3, 4] = 1

# Project to a DAG.
dag = graph_utils.project_to_dag_torch(W)
assert graph_utils.is_dag(dag)


class TestNOTMAD(unittest.TestCase):
"""Unit tests for NOTMAD."""
Expand Down
146 changes: 146 additions & 0 deletions docs/demos/project_to_dag_time.ipynb

Large diffs are not rendered by default.