Skip to content

Commit e6ffd44

Browse files
authored
Merge pull request #189 from neo4j/from_gds-hetero
Allow entities with different property sets in `from_gds` loader
2 parents 7e9a2b7 + b696f77 commit e6ffd44

File tree

6 files changed

+112
-33
lines changed

6 files changed

+112
-33
lines changed

changelog.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414
## Bug fixes
1515

16-
* Make sure that temporary internal node properties are not included in the visualization output.
16+
* Make sure that temporary internal node properties are not included in the visualization output
17+
* Fixed bug where loading a graph with `from_gds` would result in an error if some node or relationship properties were not present on every entity
1718

1819

1920
## Improvements

python-wrapper/src/neo4j_viz/gds.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414

1515

1616
def _fetch_node_dfs(
17-
gds: GraphDataScience, G: Graph, node_properties: list[str], node_labels: list[str]
17+
gds: GraphDataScience, G: Graph, node_properties_by_label: dict[str, list[str]], node_labels: list[str]
1818
) -> dict[str, pd.DataFrame]:
1919
return {
2020
lbl: gds.graph.nodeProperties.stream(
21-
G, node_properties=node_properties, node_labels=[lbl], separate_property_columns=True
21+
G, node_properties=node_properties_by_label[lbl], node_labels=[lbl], separate_property_columns=True
2222
)
2323
for lbl in node_labels
2424
}
@@ -79,24 +79,31 @@ def from_gds(
7979
"""
8080
node_properties_from_gds = G.node_properties()
8181
assert isinstance(node_properties_from_gds, pd.Series)
82-
actual_node_properties = list(chain.from_iterable(node_properties_from_gds.to_dict().values()))
82+
actual_node_properties = node_properties_from_gds.to_dict()
83+
all_actual_node_properties = list(chain.from_iterable(actual_node_properties.values()))
8384

84-
if size_property is not None and size_property not in actual_node_properties:
85-
raise ValueError(f"There is no node property '{size_property}' in graph '{G.name()}'")
85+
if size_property is not None:
86+
if size_property not in all_actual_node_properties:
87+
raise ValueError(f"There is no node property '{size_property}' in graph '{G.name()}'")
8688

8789
if additional_node_properties is None:
88-
additional_node_properties = actual_node_properties
90+
node_properties_by_label = {k: set(v) for k, v in actual_node_properties.items()}
8991
else:
9092
for prop in additional_node_properties:
91-
if prop not in actual_node_properties:
93+
if prop not in all_actual_node_properties:
9294
raise ValueError(f"There is no node property '{prop}' in graph '{G.name()}'")
9395

94-
node_properties = set()
95-
if additional_node_properties is not None:
96-
node_properties.update(additional_node_properties)
96+
node_properties_by_label = {}
97+
for label, props in actual_node_properties.items():
98+
node_properties_by_label[label] = {
99+
prop for prop in actual_node_properties[label] if prop in additional_node_properties
100+
}
101+
97102
if size_property is not None:
98-
node_properties.add(size_property)
99-
node_properties = list(node_properties)
103+
for label, props in node_properties_by_label.items():
104+
props.add(size_property)
105+
106+
node_properties_by_label = {k: list(v) for k, v in node_properties_by_label.items()}
100107

101108
node_count = G.node_count()
102109
if node_count > max_node_count:
@@ -112,13 +119,14 @@ def from_gds(
112119
property_name = None
113120
try:
114121
# Since GDS does not allow us to only fetch node IDs, we add the degree property
115-
# as a temporary property to ensure that we have at least one property to fetch
116-
if len(actual_node_properties) == 0:
122+
# as a temporary property to ensure that we have at least one property for each label to fetch
123+
if sum([len(props) == 0 for props in node_properties_by_label.values()]) > 0:
117124
property_name = f"neo4j-viz_property_{uuid4()}"
118125
gds.degree.mutate(G_fetched, mutateProperty=property_name)
119-
node_properties = [property_name]
126+
for props in node_properties_by_label.values():
127+
props.append(property_name)
120128

121-
node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties, G_fetched.node_labels())
129+
node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties_by_label, G_fetched.node_labels())
122130
if property_name is not None:
123131
for df in node_dfs.values():
124132
df.drop(columns=[property_name], inplace=True)
@@ -131,35 +139,35 @@ def from_gds(
131139
gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name])
132140

133141
for df in node_dfs.values():
134-
df.rename(columns={"nodeId": "id"}, inplace=True)
135142
if property_name is not None and property_name in df.columns:
136143
df.drop(columns=[property_name], inplace=True)
137-
rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True)
138144

139145
node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates()
140146
if size_property is not None:
141-
if "size" in actual_node_properties and size_property != "size":
147+
if "size" in all_actual_node_properties and size_property != "size":
142148
node_props_df.rename(columns={"size": "__size"}, inplace=True)
143149
node_props_df.rename(columns={size_property: "size"}, inplace=True)
144150

145151
for lbl, df in node_dfs.items():
146-
if "labels" in actual_node_properties:
152+
if "labels" in all_actual_node_properties:
147153
df.rename(columns={"labels": "__labels"}, inplace=True)
148154
df["labels"] = lbl
149155

150-
node_labels_df = pd.concat([df[["id", "labels"]] for df in node_dfs.values()], ignore_index=True, axis=0)
151-
node_labels_df = node_labels_df.groupby("id").agg({"labels": list})
156+
node_labels_df = pd.concat([df[["nodeId", "labels"]] for df in node_dfs.values()], ignore_index=True, axis=0)
157+
node_labels_df = node_labels_df.groupby("nodeId").agg({"labels": list})
152158

153-
node_df = node_props_df.merge(node_labels_df, on="id")
159+
node_df = node_props_df.merge(node_labels_df, on="nodeId")
154160

155-
if "caption" not in actual_node_properties:
161+
if "caption" not in all_actual_node_properties:
156162
node_df["caption"] = node_df["labels"].astype(str)
157163

158164
if "caption" not in rel_df.columns:
159165
rel_df["caption"] = rel_df["relationshipType"]
160166

161167
try:
162-
return _from_dfs(node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"})
168+
return _from_dfs(
169+
node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"}, dropna=True
170+
)
163171
except ValueError as e:
164172
err_msg = str(e)
165173
if "column" in err_msg:

python-wrapper/src/neo4j_viz/pandas.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ def _from_dfs(
3131
rel_dfs: DFS_TYPE,
3232
node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
3333
rename_properties: Optional[dict[str, str]] = None,
34+
dropna: bool = False,
3435
) -> VisualizationGraph:
35-
relationships = _parse_relationships(rel_dfs, rename_properties=rename_properties)
36+
relationships = _parse_relationships(rel_dfs, rename_properties=rename_properties, dropna=dropna)
3637

3738
if node_dfs is None:
3839
has_size = False
@@ -42,7 +43,7 @@ def _from_dfs(
4243
node_ids.add(rel.target)
4344
nodes = [Node(id=id) for id in node_ids]
4445
else:
45-
nodes, has_size = _parse_nodes(node_dfs, rename_properties=rename_properties)
46+
nodes, has_size = _parse_nodes(node_dfs, rename_properties=rename_properties, dropna=dropna)
4647

4748
VG = VisualizationGraph(nodes=nodes, relationships=relationships)
4849

@@ -52,7 +53,9 @@ def _from_dfs(
5253
return VG
5354

5455

55-
def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]) -> tuple[list[Node], bool]:
56+
def _parse_nodes(
57+
node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]], dropna: bool = False
58+
) -> tuple[list[Node], bool]:
5659
if isinstance(node_dfs, DataFrame):
5760
node_dfs_iter: Iterable[DataFrame] = [node_dfs]
5861
elif node_dfs is None:
@@ -67,6 +70,8 @@ def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]
6770
for node_df in node_dfs_iter:
6871
has_size &= "size" in node_df.columns
6972
for _, row in node_df.iterrows():
73+
if dropna:
74+
row = row.dropna(inplace=False)
7075
top_level = {}
7176
properties = {}
7277
for key, value in row.to_dict().items():
@@ -85,7 +90,9 @@ def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]
8590
return nodes, has_size
8691

8792

88-
def _parse_relationships(rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]) -> list[Relationship]:
93+
def _parse_relationships(
94+
rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]], dropna: bool = False
95+
) -> list[Relationship]:
8996
all_rel_field_aliases = Relationship.all_validation_aliases()
9097

9198
if isinstance(rel_dfs, DataFrame):
@@ -96,6 +103,8 @@ def _parse_relationships(rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str
96103

97104
for rel_df in rel_dfs_iter:
98105
for _, row in rel_df.iterrows():
106+
if dropna:
107+
row = row.dropna(inplace=False)
99108
top_level = {}
100109
properties = {}
101110
for key, value in row.to_dict().items():
@@ -138,4 +147,4 @@ def from_dfs(
138147
To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
139148
"""
140149

141-
return _from_dfs(node_dfs, rel_dfs, node_radius_min_max)
150+
return _from_dfs(node_dfs, rel_dfs, node_radius_min_max, dropna=False)

python-wrapper/tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def aura_ds_instance() -> Generator[Any, None, None]:
4343

4444
# setting as environment variables to run notebooks with this connection
4545
os.environ["NEO4J_URI"] = dbms_connection_info.uri
46+
assert isinstance(dbms_connection_info.username, str)
4647
os.environ["NEO4J_USER"] = dbms_connection_info.username
48+
assert isinstance(dbms_connection_info.password, str)
4749
os.environ["NEO4J_PASSWORD"] = dbms_connection_info.password
4850
yield dbms_connection_info
4951

python-wrapper/tests/gds_helper.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ def create_aurads_instance(api: AuraApi) -> tuple[str, DbmsConnectionInfo]:
6262
if wait_result.error:
6363
raise Exception(f"Error while waiting for instance to be running: {wait_result.error}")
6464

65-
wait_result.connection_url
66-
6765
return instance_details.id, DbmsConnectionInfo(
6866
uri=wait_result.connection_url,
6967
username="neo4j",

python-wrapper/tests/test_gds.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,64 @@ def test_from_gds_sample(gds: Any) -> None:
283283
assert len(VG.nodes) <= 10_500
284284
assert len(VG.relationships) >= 9_500
285285
assert len(VG.relationships) <= 10_500
286+
287+
288+
@pytest.mark.requires_neo4j_and_gds
289+
def test_from_gds_hetero(gds: Any) -> None:
290+
from neo4j_viz.gds import from_gds
291+
292+
A_nodes = pd.DataFrame(
293+
{
294+
"nodeId": [0, 1],
295+
"labels": ["A", "A"],
296+
"component": [1, 2],
297+
}
298+
)
299+
B_nodes = pd.DataFrame(
300+
{
301+
"nodeId": [2, 3],
302+
"labels": ["B", "B"],
303+
# No 'component' property
304+
}
305+
)
306+
rels = pd.DataFrame(
307+
{
308+
"sourceNodeId": [0, 1],
309+
"targetNodeId": [2, 3],
310+
"weight": [0.5, 1.5],
311+
"relationshipType": ["REL", "REL2"],
312+
}
313+
)
314+
315+
with gds.graph.construct("flo", [A_nodes, B_nodes], rels) as G:
316+
VG = from_gds(
317+
gds,
318+
G,
319+
)
320+
321+
assert len(VG.nodes) == 4
322+
assert sorted(VG.nodes, key=lambda x: x.id) == [
323+
Node(id=0, caption="['A']", properties=dict(labels=["A"], component=float(1))),
324+
Node(id=1, caption="['A']", properties=dict(labels=["A"], component=float(2))),
325+
Node(id=2, caption="['B']", properties=dict(labels=["B"])),
326+
Node(id=3, caption="['B']", properties=dict(labels=["B"])),
327+
]
328+
329+
assert len(VG.relationships) == 2
330+
vg_rels = sorted(
331+
[
332+
(
333+
e.source,
334+
e.target,
335+
e.caption,
336+
e.properties["relationshipType"],
337+
e.properties["weight"],
338+
)
339+
for e in VG.relationships
340+
],
341+
key=lambda x: x[0],
342+
)
343+
assert vg_rels == [
344+
(0, 2, "REL", "REL", 0.5),
345+
(1, 3, "REL2", "REL2", 1.5),
346+
]

0 commit comments

Comments
 (0)