
Commit bc6e4e2

adding ArangoDB support

1 parent 55c9e41 commit bc6e4e2

File tree: 16 files changed, +279 -32 lines changed

bin/eda.py (+2 -1)

@@ -7,7 +7,6 @@
 import hydra
 import matplotlib.pyplot as plt
 import networkx as nx
-import numpy as np
 from omegaconf import DictConfig
 from os.path import join as join_path
 import pandas as pd
@@ -159,10 +158,12 @@ def log_results(
 
     mlflow.end_run()
 
+
 ######################################
 # Main
 ######################################
 
+
 @hydra.main(version_base=None, config_path="../conf", config_name="config")
 def main(config: DictConfig) -> None:
     """

bin/process.py (+17 -10)

@@ -5,9 +5,7 @@
 ######################################
 
 import hydra
-import matplotlib.pyplot as plt
 import networkx as nx
-import numpy as np
 from omegaconf import DictConfig
 from os.path import join as join_path
 import pandas as pd
@@ -19,8 +17,11 @@
 
 
 def process_network(
-    feature_matrix: pd.DataFrame, edge_list: pd.DataFrame, from_col: str, to_col: str,
-    len_component: int = 5
+    feature_matrix: pd.DataFrame,
+    edge_list: pd.DataFrame,
+    from_col: str,
+    to_col: str,
+    len_component: int = 5,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Construct a graph from edge list data.
@@ -49,7 +50,7 @@ def process_network(
         if len(component) <= len_component:
             for node in component:
                 G.remove_node(node)
-
+
     nodes = list(G.nodes)
     filtered_feature_matrix = feature_matrix[nodes]
     filtered_edge_list = nx.to_pandas_edgelist(G, source=from_col, target=to_col)
@@ -60,8 +61,8 @@ def log_results(
     tracking_uri: str,
     experiment_prefix: str,
     grn_name: str,
-    feature_matrix: pd.DataFrame,
-    edge_list: pd.DataFrame
+    feature_matrix: pd.DataFrame,
+    edge_list: pd.DataFrame,
 ) -> None:
     """
     Log experiment results to the experiment tracker.
@@ -94,14 +95,16 @@ def log_results(
 
     mlflow.log_metric("num_features", len(feature_matrix.index))
     mlflow.log_metric("num_nodes", len(feature_matrix.columns))
-    mlflow.log_metric("num_1st_order_relationships", len(edge_list.index))
+    mlflow.log_metric("num_edges", len(edge_list.index))
 
     mlflow.end_run()
 
+
 ######################################
 # Main
 ######################################
 
+
 @hydra.main(version_base=None, config_path="../conf", config_name="config")
 def main(config: DictConfig) -> None:
     """
@@ -116,6 +119,7 @@ def main(config: DictConfig) -> None:
 
     DATA_DIR = config["dir"]["data_dir"]
     PREPROCESS_DIR = config["dir"]["preprocessed_dir"]
+    PROCESS_DIR = config["dir"]["processed_dir"]
     OUT_DIR = config["dir"]["out_dir"]
 
     GRN_NAME = config["grn"]["input_dir"]
@@ -131,9 +135,11 @@ def main(config: DictConfig) -> None:
     feature_matrix = pd.read_csv(join_path(input_dir, FEATURE_MATRIX_FILE))
     edge_list = pd.read_csv(join_path(input_dir, EDGE_LIST_FILE))
 
-    filtered_feature_matrix, filtered_edge_list = process_network(feature_matrix, edge_list, FROM_COL, TO_COL)
+    filtered_feature_matrix, filtered_edge_list = process_network(
+        feature_matrix, edge_list, FROM_COL, TO_COL
+    )
 
-    output_dir = join_path(DATA_DIR, OUT_DIR, GRN_NAME, "process")
+    output_dir = join_path(DATA_DIR, OUT_DIR, GRN_NAME, PROCESS_DIR)
     Path(output_dir).mkdir(parents=True, exist_ok=True)
 
     filtered_feature_matrix.to_csv(join_path(output_dir, FEATURE_MATRIX_FILE))
@@ -148,5 +154,6 @@ def main(config: DictConfig) -> None:
         filtered_edge_list,
     )
 
+
 if __name__ == "__main__":
     main()
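
A minimal sketch of how the reworked process_network signature could be exercised on a toy network; the node names, values, and the `from bin.process import ...` path are illustrative assumptions, not part of the commit.

import pandas as pd

from bin.process import process_network  # assumes bin/ is importable as a package

# Toy inputs: a feature matrix whose columns are node names, and an edge list
# with explicit "from"/"to" columns, mirroring the DataFrames read in main().
feature_matrix = pd.DataFrame({"G1": [0.1, 0.4], "G2": [0.2, 0.5], "G3": [0.3, 0.6]})
edge_list = pd.DataFrame({"from": ["G1", "G2"], "to": ["G2", "G3"]})

# Connected components with at most len_component nodes are dropped, so this
# three-node component only survives because len_component is lowered here.
filtered_features, filtered_edges = process_network(
    feature_matrix, edge_list, "from", "to", len_component=2
)
print(filtered_features.columns.tolist(), len(filtered_edges.index))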

bin/to_db.py (new file, +186)

@@ -0,0 +1,186 @@
+#!/usr/bin/env python
+
+######################################
+# Imports
+######################################
+
+from adbnx_adapter import ADBNX_Adapter
+from arango import ArangoClient
+import hydra
+import networkx as nx
+from omegaconf import DictConfig
+from os.path import join as join_path
+import pandas as pd
+
+
+######################################
+# Functions
+######################################
+
+
+def log_results(
+    tracking_uri: str,
+    experiment_prefix: str,
+    grn_name: str,
+    feature_matrix: pd.DataFrame,
+    edge_list: pd.DataFrame,
+) -> None:
+    """
+    Log experiment results to the experiment tracker.
+
+    Args:
+        tracking_uri (str):
+            The tracking URI.
+        experiment_prefix (str):
+            The experiment name prefix.
+        grn_name (str):
+            The name of the GRN.
+        feature_matrix (pd.DataFrame):
+            The feature matrix.
+        edge_list (pd.DataFrame):
+            The edge list.
+    """
+    import mlflow
+
+    mlflow.set_tracking_uri(tracking_uri)
+
+    experiment_name = f"{experiment_prefix}_process"
+    existing_exp = mlflow.get_experiment_by_name(experiment_name)
+    if not existing_exp:
+        mlflow.create_experiment(experiment_name)
+    mlflow.set_experiment(experiment_name)
+
+    mlflow.set_tag("grn", grn_name)
+
+    mlflow.log_param("grn", grn_name)
+
+    mlflow.log_metric("num_features", len(feature_matrix.index))
+    mlflow.log_metric("num_nodes", len(feature_matrix.columns))
+    mlflow.log_metric("num_edges", len(edge_list.index))
+
+    mlflow.end_run()
+
+
+######################################
+# Main
+######################################
+
+
+def get_graph(
+    feature_matrix: pd.DataFrame, edge_list: pd.DataFrame, from_col: str, to_col: str
+) -> nx.Graph:
+    """
+    Construct a graph from edge list data.
+
+    Args:
+        feature_matrix (pd.DataFrame):
+            The feature matrix.
+        edge_list (pd.DataFrame):
+            The edge list.
+        from_col (str):
+            The "from" column name.
+        to_col (str):
+            The "to" column name.
+
+    Returns:
+        nx.Graph:
+            The graph to write to the database.
+    """
+    edges = edge_list.sort_values(from_col)
+
+    G = nx.from_pandas_edgelist(edges, from_col, to_col, create_using=nx.Graph())
+    node_features = feature_matrix.to_dict()
+    nx.set_node_attributes(G, node_features, "expression")
+
+    return G
+
+
+def to_db(
+    db_host: str,
+    db_name: str,
+    db_username: str,
+    db_password: str,
+    collection: str,
+    G: nx.Graph,
+) -> None:
+    """
+    Write the graph to the database.
+
+    Args:
+        db_host (str):
+            The database host.
+        db_name (str):
+            The database name.
+        db_username (str):
+            The database username.
+        db_password (str):
+            The database password.
+        collection (str):
+            The database collection.
+        G (nx.Graph):
+            The graph.
+    """
+    sys_db = ArangoClient(hosts=db_host).db(
+        "_system", username=db_username, password=db_password
+    )
+    if not sys_db.has_database(db_name):
+        sys_db.create_database(db_name)
+    db = ArangoClient(hosts=db_host).db(
+        db_name, username=db_username, password=db_password
+    )
+
+    edges_collection = f"{collection}_edges"
+    for db_collection in [collection, edges_collection]:
+        if db.has_collection(db_collection):
+            db.delete_collection(db_collection)
+
+    if db.has_graph(collection):
+        db.delete_graph(collection)
+
+    graph_definitions = [
+        {
+            "edge_collection": edges_collection,
+            "from_vertex_collections": [collection],
+            "to_vertex_collections": [collection],
+        }
+    ]
+
+    adapter = ADBNX_Adapter(db)
+    adapter.networkx_to_arangodb(collection, G, graph_definitions)
+
+
+@hydra.main(version_base=None, config_path="../conf", config_name="config")
+def main(config: DictConfig) -> None:
+    """
+    The main entry point for the plotting pipeline.
+
+    Args:
+        config (DictConfig):
+            The pipeline configuration.
+    """
+    # Constants
+    DATA_DIR = config["dir"]["data_dir"]
+    PROCESS_DIR = config["dir"]["processed_dir"]
+    OUT_DIR = config["dir"]["out_dir"]
+
+    GRN_NAME = config["grn"]["input_dir"]
+    FEATURE_MATRIX_FILE = config["grn"]["feature_matrix"]
+    EDGE_LIST_FILE = config["grn"]["edge_list"]
+    FROM_COL = config["grn"]["from_col"]
+    TO_COL = config["grn"]["to_col"]
+
+    DB_HOST = config["db"]["host"]
+    DB_NAME = config["db"]["name"]
+    DB_USERNAME = config["db"]["username"]
+    DB_PASSWORD = config["db"]["password"]
+
+    input_dir = join_path(DATA_DIR, OUT_DIR, GRN_NAME, PROCESS_DIR)
+    feature_matrix = pd.read_csv(join_path(input_dir, FEATURE_MATRIX_FILE))
+    edge_list = pd.read_csv(join_path(input_dir, EDGE_LIST_FILE))
+
+    G = get_graph(feature_matrix, edge_list, FROM_COL, TO_COL)
+    to_db(DB_HOST, DB_NAME, DB_USERNAME, DB_PASSWORD, GRN_NAME, G)
+
+
+if __name__ == "__main__":
+    main()
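
Once the pipeline has run, the upload can be spot-checked with python-arango using the settings from conf/db/graph.yaml. A rough sketch; the collection name "in_silico" is a guess at GRN_NAME (config["grn"]["input_dir"]) and is not confirmed by this commit.

from arango import ArangoClient

# Connection values mirror conf/db/graph.yaml; adjust for a non-default setup.
db = ArangoClient(hosts="http://arangodb:8529").db(
    "grn", username="root", password="password"
)

collection = "in_silico"  # assumed value of GRN_NAME
print("graph exists:", db.has_graph(collection))
print("vertices:", db.collection(collection).count())
print("edges:", db.collection(f"{collection}_edges").count())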

conf/config.yaml (+2)

@@ -1,6 +1,7 @@
 defaults:
   - _self_
   - grn: in_silico
+  - db: graph
   - experiment_tracking: docker
 
 experiment:
@@ -9,4 +10,5 @@ experiment:
 dir:
   data_dir: data
   preprocessed_dir: preprocessed
+  processed_dir: processed
   out_dir: out

conf/db/graph.yaml (new file, +4)

@@ -0,0 +1,4 @@
+host: http://arangodb:8529
+name: grn
+username: root
+password: password
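
The new db config group is pulled in through the `- db: graph` default added to conf/config.yaml, so any script decorated with @hydra.main sees it under config["db"]. Below is a minimal sketch of reading it outside a pipeline run via Hydra's compose API; the relative config_path is an assumption about where the snippet lives.

from hydra import compose, initialize

# config_path is resolved relative to the calling file, so this assumes the
# snippet sits next to the conf/ directory at the repository root.
with initialize(version_base=None, config_path="conf"):
    config = compose(config_name="config")
    print(config["db"]["host"])  # http://arangodb:8529
    print(config["db"]["name"])  # grn

At the command line the same values can be overridden per run, e.g. python bin/to_db.py db.host=http://localhost:8529.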

docker-compose.yml (+1)

@@ -16,6 +16,7 @@ services:
       AWS_ACCESS_KEY_ID: $AWS_ACCESS_KEY_ID
       AWS_SECRET_ACCESS_KEY: $AWS_SECRET_ACCESS_KEY
       MLFLOW_S3_ENDPOINT_URL: $MLFLOW_S3_ENDPOINT_URL
+      ARANGO_ROOT_PASSWORD: $ARANGO_ROOT_PASSWORD
     volumes:
       - ${PWD}:${PWD}:Z
       - /var/run/docker.sock:/var/run/docker.sock

docs/source/pipelines/index.rst (+2)

@@ -6,3 +6,5 @@ Nextflow Graph Machine Learning Pipelines
    :caption: Contents:
 
    eda.rst
+   process.rst
+   to_db.rst

docs/source/pipelines/process.rst (new file, +7)

@@ -0,0 +1,7 @@
+Process Data
+=================================================
+
+*Date published:* |today|
+
+.. automodule:: bin.process
+   :members:

docs/source/pipelines/to_db.rst (new file, +7)

@@ -0,0 +1,7 @@
+To Graph Database
+=================================================
+
+*Date published:* |today|
+
+.. automodule:: bin.to_db
+   :members:
