Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions libs/langchain-mongodb/langchain_mongodb/graphrag/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Package entry point for ``langchain_mongodb.graphrag``: re-export the graph
# store so it can be imported directly from the subpackage.
from langchain_mongodb.graphrag.graph import MongoDBGraphStore

# Names exported by ``from langchain_mongodb.graphrag import *``.
__all__ = ["MongoDBGraphStore"]
142 changes: 141 additions & 1 deletion libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from copy import deepcopy
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
Expand All @@ -30,6 +30,9 @@
Entity: TypeAlias = Dict[str, Any]
"""Represents an Entity in the knowledge graph with specific schema. See .schema"""

import holoviews # type: ignore[import-untyped]
import networkx

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -544,3 +547,140 @@ def chat_response(
entity_schema=entity_schema,
)
)

def to_networkx(self, **kwargs: Any) -> networkx.DiGraph:
    """Convert the entity collection to a `NetworkX DiGraph <https://networkx.org/documentation/stable/index.html>`_.

    Every document in the collection becomes a node keyed by its ``_id``.
    Every entry of a document's ``relationships["target_ids"]`` becomes a
    directed edge from that document to the target, provided the target is
    itself a document in the collection (otherwise a warning is logged and
    the edge is skipped). The ``type`` and ``attributes`` of nodes and edges
    are stored on the graph as JSON strings.

    NOTE: Requires optional-dependency "viz", i.e. `uv sync --extra viz`

    Args:
        **kwargs: Keyword arguments attached as attributes to the graph and to
            every node and edge. Those that are ``json.dumps`` options
            (``indent``, ``sort_keys``, ``default``, ...) are also used to
            format the JSON strings. The computed ``type`` and ``attributes``
            values take precedence over keyword arguments of the same name.

    Returns: networkx.DiGraph

    Raises:
        ImportError: If networkx is not installed.
    """
    import json

    try:
        import networkx as nx
    except ImportError as e:
        raise ImportError(
            "Install optional-dependency `viz` for networkx or to view in Holoviews"
        ) from e

    # json.dumps forwards unknown keywords to the encoder constructor, which
    # raises TypeError. Only hand it the options it actually understands so
    # that graph/node/edge attributes in kwargs do not break serialization.
    json_option_names = (
        "skipkeys",
        "ensure_ascii",
        "check_circular",
        "allow_nan",
        "cls",
        "indent",
        "separators",
        "default",
        "sort_keys",
    )
    json_opts = {k: v for k, v in kwargs.items() if k in json_option_names}

    def _dumps(value: Any) -> str:
        return json.dumps(value, **json_opts)

    def _safe_get(lst: list, i: int, default: Any = "") -> Any:
        return lst[i] if i < len(lst) else default

    # Single query: add every node now and remember the relationships so the
    # edges can be added once all nodes are known.
    nx_graph = nx.DiGraph(**kwargs)
    pending_edges = []
    for doc in self.collection.find({}):
        node_id = doc["_id"]
        # Computed keys come last so they win over same-named kwargs instead
        # of raising "multiple values for keyword argument".
        node_attrs = {
            **kwargs,
            "type": _dumps(doc.get("type", "")),
            "attributes": _dumps(doc.get("attributes", {})),
        }
        nx_graph.add_node(node_id, **node_attrs)

        relationships = doc.get("relationships") or {}
        if relationships.get("target_ids"):
            pending_edges.append((node_id, relationships))

    # relationships holds parallel lists: target_ids[i] has types[i] and
    # attributes[i]; missing entries in the latter two default to "".
    for source_id, relationships in pending_edges:
        target_ids = relationships.get("target_ids") or []
        types = relationships.get("types") or []
        attrs = relationships.get("attributes") or []

        for t, target_id in enumerate(target_ids):
            if not nx_graph.has_node(target_id):
                logger.warning(
                    f"{source_id=} references {target_ids[t]=} not found in collection"
                )
                continue
            edge_attrs = {
                **kwargs,
                "type": _dumps(_safe_get(types, t)),
                "attributes": _dumps(_safe_get(attrs, t)),
            }
            nx_graph.add_edge(source_id, target_id, **edge_attrs)

    return nx_graph

def view(
    self,
    layout: Optional[Callable] = None,
    nx_opts: Optional[dict] = None,
    edge_opts: Optional[dict] = None,
    node_opts: Optional[dict] = None,
) -> holoviews.Graph:
    """Draw the knowledge graph as an interactive Holoviews/Bokeh plot.

    The entity collection is first converted to a NetworkX graph
    (see ``to_networkx``), which is then converted to a Holoviews Graph
    via ``holoviews.Graph.from_networkx``.

    Both libraries are feature rich; those interested in visualization
    and/or graph analysis are encouraged to consult their documentation
    for the many customization options.

    The default layout is ``networkx.spring_layout``. As our entities have
    a type field, another good layout choice is:
    `layout=nx.multipartite_layout, nx_opts={"subset_key": "type"}`

    NOTE: Requires optional-dependency "viz", i.e. `uv sync --extra viz`

    Args:
        layout: `networkx layout. <https://networkx.org/documentation/stable/reference/drawing.html#module-networkx.drawing.layout>`_
            Defaults to networkx.spring_layout.
        nx_opts: Keyword arguments for layout function.
        edge_opts: Keyword arguments to draw edges. May override the default
            ``inspection_policy="edges"``.
        node_opts: Keyword arguments to draw nodes.

    Returns: `holoviews.Graph <https://holoviews.org/user_guide/Network_Graphs.html>`_
        combined (``*``) with its nodes element, ready for display.

    Raises:
        ImportError: If holoviews or networkx is not installed.
    """
    try:
        import holoviews as hv
        import networkx as nx

        hv.extension("bokeh")
    except ImportError as e:
        raise ImportError("To view graph, install optional-dependency `viz`") from e

    # Hide axes for graph plots; node positions carry no meaningful units.
    hv.opts.defaults(
        hv.opts.Graph(xaxis=None, yaxis=None), hv.opts.Nodes(xaxis=None, yaxis=None)
    )

    if layout is None:
        layout = nx.spring_layout
    layout_opts = {} if nx_opts is None else nx_opts

    # Convert entity collection to NetworkX graph. Layout options are NOT
    # forwarded here: they are not valid JSON-formatting options and would
    # make the conversion fail.
    nx_graph = self.to_networkx()
    # Convert to HoloViews Graph; extra keywords go to the layout function.
    hv_graph = hv.Graph.from_networkx(nx_graph, layout, **layout_opts)

    # Display with hover tools over edges and nodes. The default policy is
    # merged first so a caller-supplied inspection_policy overrides it
    # rather than raising a duplicate-keyword error.
    graph_style = {"inspection_policy": "edges", **(edge_opts or {})}
    node_style = {} if node_opts is None else node_opts
    return hv_graph.opts(**graph_style) * hv_graph.nodes.opts(**node_style)


# TODO
# Discuss API => split out json_opts and nx_opts
# Document
# Choose reasonable defaults
# - Currently I have sidestepped them leaving standard defaults.
# - dpi, width, height
# - where do we specify these? hv.opts?
# - Document layouts:
#   - nx.spring_layout (100-1000 nodes)
#   - nx.kamada_kawai_layout (< 100 nodes)
#   - The following require Graphviz: `brew install graphviz`, then `pip install pydot graphviz`
#     - layout=nx.nx_pydot.graphviz_layout, nx_opts={"prog": "sfdp"} (1000+ nodes)
#     - other programs: prog="dot" for hierarchical knowledge graphs, or prog="twopi", e.g.
#       layout=lambda x: nx.nx_pydot.graphviz_layout(x, prog="twopi")
#   - layout=nx.multipartite_layout, nx_opts={"subset_key": "type"}  # If you have different node types (e.g., Person, Organization, Location)
# Export collection to play with in notebook ==> view_graphstore.ipynb
# Create a much larger graph to test defaults. => Use Claude Code to build tests.
# Test save. (Then Document it.)
15 changes: 15 additions & 0 deletions libs/langchain-mongodb/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ dev = [
"typing-extensions>=4.12.2",
]

[project.optional-dependencies]
viz = [
"networkx",
"holoviews",
"jupyter", # TODO All of the follow provide further functionality and layouts.
# "pandas",
# "numpy",
# "scikit-image",
# "pyparsing",
# "datashader", # TODO - Review inclusion of datashader, used for grouping nodes in very large graphs
# "dask",
# "graphviz", TODO - Review inclusion of graphviz, used for additional plotting layouts
# "pydot",
]

[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
markers = [
Expand Down
Loading