Commit e99ebb0

Expose context from the orchestrator to the components (#301)

* Expose context from the orchestrator to the components
* Rename to match component name + ruff
* Add tests
* Update changelog and example readme
* unused import
* mypy
* Changes so that the `run` method is not required anymore
* Update documentation
* We still need to raise error if components do not have at least one of the two methods implemented
* Update CHANGELOG.md
* Improve documentation of future changes
* Undo make notifier private, not needed

1 parent 92a176f commit e99ebb0
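
The commit message implies a dispatch rule: `run` is no longer required, but a component must still implement at least one of `run` and `run_with_context`. The sketch below shows one way such a rule could be enforced. `SketchComponent` and its `__init_subclass__` check are hypothetical stand-ins for the library's `Component`, whose actual mechanism is not shown in this excerpt and may differ.

```python
import asyncio
from typing import Any


class SketchComponent:
    """Hypothetical stand-in for a pipeline component base class."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # A method counts as implemented when it is not the base-class default.
        has_run = cls.run is not SketchComponent.run
        has_ctx = cls.run_with_context is not SketchComponent.run_with_context
        if not (has_run or has_ctx):
            raise TypeError(
                f"{cls.__name__} must implement 'run' or 'run_with_context'"
            )

    async def run(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError()

    async def run_with_context(self, context_: Any, *args: Any, **kwargs: Any) -> Any:
        # Default: ignore the context and fall back to the plain `run` method.
        return await self.run(*args, **kwargs)


class LegacyAdd(SketchComponent):
    # Implements only `run`; still callable through `run_with_context`.
    async def run(self, number1: int, number2: int = 1) -> int:
        return number1 + number2


class ContextAdd(SketchComponent):
    # Implements only `run_with_context`; `run` stays unimplemented.
    async def run_with_context(
        self, context_: Any, number1: int, number2: int = 1
    ) -> int:
        return number1 + number2
```

With this shape, an orchestrator can always call `run_with_context`, and components written before the change keep working through the fallback.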

32 files changed, +572 -211 lines changed

CHANGELOG.md

+6
@@ -2,13 +2,19 @@

 ## Next

+### Added
+
+- Added the `run_with_context` method to `Component`. This method includes a `context_` parameter, which provides information about the pipeline from which the component is executed (e.g., the `run_id`). It also enables the component to send events to the pipeline's callback function.
+
+
 ## 1.6.0

 ### Added

 - Added optional schema enforcement as a validation layer after entity and relation extraction.
 - Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter.
 - Introduced SearchQueryParseError for handling invalid Lucene query strings in HybridRetriever and HybridCypherRetriever.
+- Components can now be called with the `run_with_context` method that gets an extra `context_` argument containing information about the pipeline it's run from: the `run_id`, `task_name` and a `notify` function that can be used to send `TASK_PROGRESS` events to the same callback as the pipeline events.

 ### Fixed
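
The changelog entries above name what the context carries (`run_id`, `task_name`, `notify`) and where `notify` sends its events. A minimal sketch of that wiring follows. `SketchEvent` and `SketchRunContext` are hypothetical stand-ins, not the library's `RunContext`; any field not named in the changelog is an assumption.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Dict, List


@dataclass
class SketchEvent:
    """Hypothetical event record; field names are illustrative only."""

    event_type: str
    run_id: str
    task_name: str
    message: str
    payload: Dict[str, Any] = field(default_factory=dict)


@dataclass
class SketchRunContext:
    """Hypothetical context handed to a component for one task run."""

    run_id: str
    task_name: str
    callback: Callable[[SketchEvent], Awaitable[None]]

    async def notify(self, message: str, data: Dict[str, Any]) -> None:
        # Forward a TASK_PROGRESS event to the same callback the pipeline uses.
        await self.callback(
            SketchEvent("TASK_PROGRESS", self.run_id, self.task_name, message, data)
        )


async def demo() -> List[SketchEvent]:
    received: List[SketchEvent] = []

    async def callback(event: SketchEvent) -> None:
        received.append(event)

    context_ = SketchRunContext(run_id="run-1", task_name="add", callback=callback)
    await context_.notify(message="halfway", data={"iteration": 5, "total": 10})
    return received
```

Running `asyncio.run(demo())` returns a single event tagged `TASK_PROGRESS` that carries the run and task identifiers alongside the component's own message and data.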

docs/source/api.rst

+6-2
@@ -9,13 +9,18 @@ API Documentation
 Components
 **********

+Component
+=========
+
+.. autoclass:: neo4j_graphrag.experimental.pipeline.component.Component
+    :members: run, run_with_context
+
 DataLoader
 ==========

 .. autoclass:: neo4j_graphrag.experimental.components.pdf_loader.DataLoader
     :members: run, get_document_metadata

-
 PdfLoader
 =========

@@ -59,7 +64,6 @@ LexicalGraphBuilder
     :members:
     :exclude-members: component_inputs, component_outputs

-
 Neo4jChunkReader
 ================

docs/source/types.rst

+4-4
@@ -158,22 +158,22 @@ ParamFromEnvConfig
 EventType
 =========

-.. autoenum:: neo4j_graphrag.experimental.pipeline.types.EventType
+.. autoenum:: neo4j_graphrag.experimental.pipeline.notification.EventType


 PipelineEvent
 ==============

-.. autoclass:: neo4j_graphrag.experimental.pipeline.types.PipelineEvent
+.. autoclass:: neo4j_graphrag.experimental.pipeline.notification.PipelineEvent

 TaskEvent
 ==============

-.. autoclass:: neo4j_graphrag.experimental.pipeline.types.TaskEvent
+.. autoclass:: neo4j_graphrag.experimental.pipeline.notification.TaskEvent


 EventCallbackProtocol
 =====================

-.. autoclass:: neo4j_graphrag.experimental.pipeline.types.EventCallbackProtocol
+.. autoclass:: neo4j_graphrag.experimental.pipeline.notification.EventCallbackProtocol
     :members: __call__

docs/source/user_guide_pipeline.rst

+34-2
@@ -22,7 +22,7 @@ their own by following these steps:

 1. Create a subclass of the Pydantic `neo4j_graphrag.experimental.pipeline.DataModel` to represent the data being returned by the component
 2. Create a subclass of `neo4j_graphrag.experimental.pipeline.Component`
-3. Create a run method in this new class and specify the required inputs and output model using the just created `DataModel`
+3. Create a `run_with_context` method in this new class and specify the required inputs and output model using the just created `DataModel`
 4. Implement the run method: it's an `async` method, allowing tasks to be parallelized and awaited within this method.

 An example is given below, where a `ComponentAdd` is created to add two numbers together and return
@@ -31,12 +31,13 @@ the resulting sum:
 .. code:: python

     from neo4j_graphrag.experimental.pipeline import Component, DataModel
+    from neo4j_graphrag.experimental.pipeline.types.context import RunContext

     class IntResultModel(DataModel):
         result: int

     class ComponentAdd(Component):
-        async def run(self, number1: int, number2: int = 1) -> IntResultModel:
+        async def run_with_context(self, context_: RunContext, number1: int, number2: int = 1) -> IntResultModel:
             return IntResultModel(result = number1 + number2)

 Read more about :ref:`components-section` in the API Documentation.
@@ -141,6 +142,7 @@ It is possible to add a callback to receive notification about pipeline progress
 - `PIPELINE_STARTED`, when pipeline starts
 - `PIPELINE_FINISHED`, when pipeline ends
 - `TASK_STARTED`, when a task starts
+- `TASK_PROGRESS`, sent by each component (depends on component's implementation, see below)
 - `TASK_FINISHED`, when a task ends


@@ -172,3 +174,33 @@ See :ref:`pipelineevent` and :ref:`taskevent` to see what is sent in each event
     # ... add components, connect them as usual

     await pipeline.run(...)
+
+
+Send Events from Components
+===========================
+
+Components can send progress notifications using the `notify` function from
+`context_` by implementing the `run_with_context` method:
+
+.. code:: python
+
+    from neo4j_graphrag.experimental.pipeline import Component, DataModel
+    from neo4j_graphrag.experimental.pipeline.types.context import RunContext
+
+    class IntResultModel(DataModel):
+        result: int
+
+    class ComponentAdd(Component):
+        async def run_with_context(self, context_: RunContext, number1: int, number2: int = 1) -> IntResultModel:
+            for fake_iteration in range(10):
+                await context_.notify(
+                    message=f"Starting iteration {fake_iteration} out of 10",
+                    data={"iteration": fake_iteration, "total": 10}
+                )
+            return IntResultModel(result = number1 + number2)
+
+This will send a `TASK_PROGRESS` event to the pipeline callback.
+
+.. note::
+
+    In a future release, the `context_` parameter will be added to the `run` method.

examples/README.md

+1
@@ -103,6 +103,7 @@ are listed in [the last section of this file](#customize).
 - [Export lexical graph creation into another pipeline](./customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py)
 - [Build pipeline from config file](customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py)
 - [Add event listener to get notification about Pipeline progress](./customize/build_graph/pipeline/pipeline_with_notifications.py)
+- [Use component context to send notifications about Component progress](./customize/build_graph/pipeline/pipeline_with_component_notifications.py)


 #### Components

examples/build_graph/simple_kg_builder_from_text.py

+1-1
@@ -14,7 +14,7 @@
 from neo4j_graphrag.embeddings import OpenAIEmbeddings
 from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
-from neo4j_graphrag.experimental.pipeline.types import (
+from neo4j_graphrag.experimental.pipeline.types.schema import (
     EntityInputType,
     RelationInputType,
 )

examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py

+109
@@ -0,0 +1,109 @@
+"""This example demonstrates how to use event callback to receive notifications
+about the component progress.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any
+
+from neo4j_graphrag.experimental.pipeline import Pipeline, Component, DataModel
+from neo4j_graphrag.experimental.pipeline.notification import Event, EventType
+from neo4j_graphrag.experimental.pipeline.types.context import RunContext
+
+logger = logging.getLogger(__name__)
+logging.basicConfig()
+logger.setLevel(logging.INFO)
+
+
+class MultiplyComponentResult(DataModel):
+    result: list[int]
+
+
+class MultiplicationComponent(Component):
+    def __init__(self, f: int) -> None:
+        self.f = f
+
+    async def multiply_number(
+        self,
+        context_: RunContext,
+        number: int,
+    ) -> int:
+        await context_.notify(
+            message=f"Processing number {number}",
+            data={"number_processed": number},
+        )
+        return self.f * number
+
+    # implementing `run_with_context` to get access to
+    # the pipeline's RunContext:
+    async def run_with_context(
+        self,
+        context_: RunContext,
+        numbers: list[int],
+        **kwargs: Any,
+    ) -> MultiplyComponentResult:
+        result = await asyncio.gather(
+            *[
+                self.multiply_number(
+                    context_,
+                    number,
+                )
+                for number in numbers
+            ]
+        )
+        return MultiplyComponentResult(result=result)
+
+
+async def event_handler(event: Event) -> None:
+    """Function can do anything about the event,
+    here we're just logging it if it's a pipeline-level event.
+    """
+    if event.event_type == EventType.TASK_PROGRESS:
+        logger.warning(event)
+    else:
+        logger.info(event)
+
+
+async def main() -> None:
+    """ """
+    pipe = Pipeline(
+        callback=event_handler,
+    )
+    # define the components
+    pipe.add_component(
+        MultiplicationComponent(f=2),
+        "multiply_by_2",
+    )
+    pipe.add_component(
+        MultiplicationComponent(f=10),
+        "multiply_by_10",
+    )
+    # define the execution order of component
+    # and how the output of previous components must be used
+    pipe.connect(
+        "multiply_by_2",
+        "multiply_by_10",
+        input_config={"numbers": "multiply_by_2.result"},
+    )
+    # user input:
+    pipe_inputs_1 = {
+        "multiply_by_2": {
+            "numbers": [1, 2, 5, 4],
+        },
+    }
+    pipe_inputs_2 = {
+        "multiply_by_2": {
+            "numbers": [3, 10, 1],
+        }
+    }
+    # run the pipeline
+    await asyncio.gather(
+        pipe.run(pipe_inputs_1),
+        pipe.run(pipe_inputs_2),
+    )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
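
For reference, the arithmetic this example performs can be reproduced without the library. The sketch below uses plain `asyncio` and hypothetical helper names (`multiply_all`, `chained`, `sketch_main`). It mirrors the two chained multiplication stages and the two concurrent runs, but not the pipeline's event notifications or its result type.

```python
import asyncio
from typing import List


async def multiply_all(factor: int, numbers: List[int]) -> List[int]:
    async def multiply(number: int) -> int:
        # Yield to the event loop once, as an awaited notification would.
        await asyncio.sleep(0)
        return factor * number

    # asyncio.gather returns results in the order the awaitables were passed.
    return list(await asyncio.gather(*(multiply(n) for n in numbers)))


async def chained(numbers: List[int]) -> List[int]:
    # Stage 1 multiplies by 2; its output feeds stage 2, which multiplies by 10.
    doubled = await multiply_all(2, numbers)
    return await multiply_all(10, doubled)


async def sketch_main() -> List[List[int]]:
    # Two independent runs executed concurrently, like the two pipe.run calls.
    return list(await asyncio.gather(chained([1, 2, 5, 4]), chained([3, 10, 1])))


if __name__ == "__main__":
    print(asyncio.run(sketch_main()))  # [[20, 40, 100, 80], [60, 200, 20]]
```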

examples/customize/build_graph/pipeline/pipeline_with_notifications.py

+1-1
@@ -15,7 +15,7 @@
 )
 from neo4j_graphrag.experimental.pipeline import Pipeline
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
-from neo4j_graphrag.experimental.pipeline.types import Event
+from neo4j_graphrag.experimental.pipeline.notification import Event

 logger = logging.getLogger(__name__)
 logging.basicConfig()

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

+2-4
@@ -14,7 +14,6 @@
 # limitations under the License.
 from __future__ import annotations

-import abc
 import asyncio
 import enum
 import json
@@ -115,7 +114,7 @@ def fix_invalid_json(raw_json: str) -> str:
     return repaired_json


-class EntityRelationExtractor(Component, abc.ABC):
+class EntityRelationExtractor(Component):
     """Abstract class for entity relation extraction components.

     Args:
@@ -133,15 +132,14 @@ def __init__(
         self.on_error = on_error
         self.create_lexical_graph = create_lexical_graph

-    @abc.abstractmethod
     async def run(
         self,
         chunks: TextChunks,
         document_info: Optional[DocumentInfo] = None,
         lexical_graph_config: Optional[LexicalGraphConfig] = None,
         **kwargs: Any,
     ) -> Neo4jGraph:
-        pass
+        raise NotImplementedError()

     def update_ids(
         self,

src/neo4j_graphrag/experimental/components/resolver.py

+2-4
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import abc
 from typing import Any, Optional

 import neo4j
@@ -22,7 +21,7 @@
 from neo4j_graphrag.utils import driver_config


-class EntityResolver(Component, abc.ABC):
+class EntityResolver(Component):
     """Entity resolution base class

     Args:
@@ -38,9 +37,8 @@ def __init__(
         self.driver = driver_config.override_user_agent(driver)
         self.filter_query = filter_query

-    @abc.abstractmethod
     async def run(self, *args: Any, **kwargs: Any) -> ResolutionStats:
-        pass
+        raise NotImplementedError()


 class SinglePropertyExactMatchResolver(EntityResolver):
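
Both base classes above drop `abc.abstractmethod` in favour of a `run` that raises `NotImplementedError`. The practical difference is when a missing `run` is detected: at instantiation with an abstract method, or only when `run` is actually called. The sketch below illustrates this with hypothetical class names that are unrelated to the library's code.

```python
import abc
import asyncio
from typing import Type


class AbstractBase(abc.ABC):
    @abc.abstractmethod
    async def run(self) -> str:
        ...


class DeferredBase:
    async def run(self) -> str:
        raise NotImplementedError()


class AbstractChild(AbstractBase):
    """Does not override `run`."""


class DeferredChild(DeferredBase):
    """Does not override `run`."""


def instantiation_fails(cls: Type[object]) -> bool:
    # abc refuses to instantiate a class that still has abstract methods.
    try:
        cls()
    except TypeError:
        return True
    return False


def call_fails(cls: Type[DeferredBase]) -> bool:
    # Without abc, the error only surfaces once `run` is awaited.
    try:
        asyncio.run(cls().run())
    except NotImplementedError:
        return True
    return False
```

Deferring the error presumably lets a subclass implement only `run_with_context` and still be instantiated, which an abstract `run` would prevent.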

src/neo4j_graphrag/experimental/components/schema.py

+1-1
@@ -21,7 +21,7 @@

 from neo4j_graphrag.exceptions import SchemaValidationError
 from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
-from neo4j_graphrag.experimental.pipeline.types import (
+from neo4j_graphrag.experimental.pipeline.types.schema import (
     EntityInputType,
     RelationInputType,
 )
