Skip to content

Commit 97be372

Browse files
Add Opensearch Writer Reliability (#1130)
* Initial dev * Add integration test, simplify API * Address comments * Fix test bug, address comment * Change node check to BaseDBWriter * BaseDBWriter client signature
1 parent 214fad2 commit 97be372

File tree

4 files changed

+167
-12
lines changed

4 files changed

+167
-12
lines changed

lib/sycamore/sycamore/connectors/base_writer.py

+9
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ def create_target_idempotent(self, target_params: "BaseDBWriter.TargetParams"):
3333
def get_existing_target_params(self, target_params: "BaseDBWriter.TargetParams") -> "BaseDBWriter.TargetParams":
3434
pass
3535

36+
def reliability_assertor(self, target_params: "BaseDBWriter.TargetParams"):
    """
    Verify that every document was written when reliability mode is enabled.

    This base implementation performs no verification and unconditionally
    raises; writers that support reliability checks provide their own version.

    Args:
        target_params: Parameters describing the write target to verify.

    Raises:
        NotImplementedError: Always, because this implementation does not
            support reliability checks.
    """
    message = "This writer does not support reliability checks"
    raise NotImplementedError(message)
44+
3645
def close(self):
3746
pass
3847

lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py

+82-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
DEFAULT_RECORD_PROPERTIES,
1717
)
1818
from sycamore.utils.import_utils import requires_modules
19+
from sycamore.plan_nodes import Node
20+
from sycamore.docset import DocSet
21+
from sycamore.context import Context
1922

2023
if typing.TYPE_CHECKING:
2124
from opensearchpy import OpenSearch
@@ -42,6 +45,7 @@ class OpenSearchWriterClientParams(BaseDBWriter.ClientParams):
4245
@dataclass
4346
class OpenSearchWriterTargetParams(BaseDBWriter.TargetParams):
4447
index_name: str
48+
_doc_count: int = 0
4549
settings: dict[str, Any] = field(default_factory=lambda: {"index.knn": True})
4650
mappings: dict[str, Any] = field(
4751
default_factory=lambda: {
@@ -92,6 +96,61 @@ def compatible_with(self, other: "BaseDBWriter.TargetParams") -> bool:
9296
other_flat_mappings = dict(flatten_data(other.mappings))
9397
return check_dictionary_compatibility(my_flat_mappings, other_flat_mappings)
9498

99+
@classmethod
def from_write_args(
    cls,
    index_name: str,
    plan: Node,
    context: Context,
    reliability_rewriter: bool,
    execute: bool,
    insert_settings: Optional[dict] = None,
    index_settings: Optional[dict] = None,
) -> "OpenSearchWriterTargetParams":
    """
    Build OpenSearchWriterTargetParams from write operation arguments.

    When ``reliability_rewriter`` is enabled, the plan is validated and the
    number of documents it produces is counted up front and stored in
    ``_doc_count``. Otherwise ``_doc_count`` is 0.

    Args:
        index_name: Name of the OpenSearch index
        plan: The execution plan Node
        context: The execution Context
        reliability_rewriter: Whether to enable reliability rewriter mode
        execute: Whether to execute the pipeline immediately
        insert_settings: Optional settings for data insertion
        index_settings: Optional index configuration settings; "settings" and
            "mappings" are read from its "body" entry

    Returns:
        OpenSearchWriterTargetParams configured with the provided settings

    Raises:
        AssertionError: If reliability_rewriter conditions are not met
    """
    target_params_dict: dict[str, Any] = {
        "index_name": index_name,
        "_doc_count": 0,
    }

    if reliability_rewriter:
        # These preconditions are raised explicitly (rather than via `assert`)
        # so they are still enforced when Python runs with -O. The exception
        # type and messages are unchanged.
        if not execute:
            raise AssertionError("Reliability rewriter requires execute to be True")

        from sycamore.materialize import Materialize

        # isinstance (not an exact type comparison) so Materialize subclasses
        # are accepted as well.
        if not isinstance(plan, Materialize):
            raise AssertionError("The first node must be a materialize node for reliability rewriter to work")
        if plan.children[0]:
            raise AssertionError(
                "Pipeline should only have read materialize and write nodes for reliability rewriter to work"
            )
        # Executes a count over the plan to learn how many documents to expect.
        target_params_dict["_doc_count"] = DocSet(context, plan).count()

    if insert_settings:
        target_params_dict["insert_settings"] = insert_settings

    if index_settings:
        body = index_settings.get("body", {})
        target_params_dict["settings"] = body.get("settings", {})
        target_params_dict["mappings"] = body.get("mappings", {})

    return cls(**target_params_dict)
153+
95154

96155
class OpenSearchWriterClient(BaseDBWriter.Client):
97156
def __init__(self, os_client: "OpenSearch"):
@@ -187,6 +246,8 @@ def _string_values_to_python_types(obj: Any):
187246
return obj
188247
return obj
189248

249+
# TODO: Convert OpenSearchWriterTargetParams to pydantic model
250+
190251
assert isinstance(
191252
target_params, OpenSearchWriterTargetParams
192253
), f"Provided target_params was not of type OpenSearchWriterTargetParams:\n{target_params}"
@@ -196,7 +257,27 @@ def _string_values_to_python_types(obj: Any):
196257
assert isinstance(mappings, dict)
197258
settings = _string_values_to_python_types(response.get(index_name, {}).get("settings", {}))
198259
assert isinstance(settings, dict)
199-
return OpenSearchWriterTargetParams(index_name=index_name, mappings=mappings, settings=settings)
260+
_doc_count = target_params._doc_count
261+
assert isinstance(_doc_count, int)
262+
return OpenSearchWriterTargetParams(
263+
index_name=index_name,
264+
mappings=mappings,
265+
settings=settings,
266+
_doc_count=_doc_count,
267+
)
268+
269+
def reliability_assertor(self, target_params: BaseDBWriter.TargetParams):
    """
    Verify that the index holds exactly the expected number of documents.

    Flushes the index, reads its document count via the cat-indices API and
    compares it with ``target_params._doc_count``.

    Args:
        target_params: Must be an OpenSearchWriterTargetParams; supplies the
            index name and the expected document count.

    Raises:
        AssertionError: If ``target_params`` has the wrong type, if the index
            lookup does not return exactly one index, or if the document count
            differs from the expected count.
    """
    assert isinstance(
        target_params, OpenSearchWriterTargetParams
    ), f"Provided target_params was not of type OpenSearchWriterTargetParams:\n{target_params}"
    index_name = target_params.index_name
    expected_docs = target_params._doc_count

    log.info("Flushing index...")
    self._client.indices.flush(index=index_name, params={"timeout": 300})
    log.info("Done flushing index.")

    indices = self._client.cat.indices(index=index_name, format="json")
    # The checks below are the whole point of this method, so they are raised
    # explicitly: a bare `assert` would be stripped under `python -O` and the
    # verification would silently pass.
    if len(indices) != 1:
        raise AssertionError(f"Expected 1 index, found {len(indices)}")
    num_docs = int(indices[0]["docs.count"])
    log.info(f"{num_docs} chunks written in index {index_name}")
    if num_docs != expected_docs:
        raise AssertionError(f"Expected {expected_docs} docs, found {num_docs}")
200281

201282

202283
@dataclass

lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py

+50
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,56 @@ def remove_reconstruct_doc_property(doc: Document):
248248
# Clean up
249249
os_client.indices.delete(setup_index, ignore_unavailable=True)
250250

251+
def test_write_with_reliability(self, setup_index, os_client, exec_mode):
    """
    Validates that when materialized pickle outputs are deleted, the index is rewritten
    with the correct (reduced) number of chunks.
    """
    with tempfile.TemporaryDirectory() as tmpdir1:
        path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf")
        context = sycamore.init(exec_mode=exec_mode)

        # 2 docs for ray execution
        (
            context.read.binary([path, path], binary_format="pdf")
            .partition(ArynPartitioner(aryn_api_key=ARYN_API_KEY))
            .explode()
            .materialize(path=tmpdir1)
            .execute()
        )

        # First write: index everything that was materialized.
        (
            context.read.materialize(tmpdir1).write.opensearch(
                os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
                index_name=setup_index,
                index_settings=TestOpenSearchRead.INDEX_SETTINGS,
                reliability_rewriter=True,
            )
        )
        os_client.indices.refresh(setup_index)
        count = get_doc_count(os_client, setup_index)

        # Delete 1 pickle file to make sure reliability rewriter works
        pickle_files = [f for f in os.listdir(tmpdir1) if f.endswith(".pickle")]
        assert pickle_files, "No pickle files found in materialized directory"
        os.remove(os.path.join(tmpdir1, pickle_files[0]))

        # Delete and recreate the index - should have fewer chunks
        (
            context.read.materialize(tmpdir1).write.opensearch(
                os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
                index_name=setup_index,
                index_settings=TestOpenSearchRead.INDEX_SETTINGS,
                reliability_rewriter=True,
            )
        )
        os_client.indices.refresh(setup_index)
        re_count = get_doc_count(os_client, setup_index)

        # Verify document count is reduced by exactly the one removed pickle file.
        # The message reports the value actually expected (count - 1).
        assert count - 1 == re_count, f"Expected {count - 1} documents, found {re_count}"
        os_client.indices.delete(setup_index)
300+
251301
def _test_ingest_and_read_via_docid_reconstructor(self, setup_index, os_client, cache_dir):
252302
"""
253303
Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents.

lib/sycamore/sycamore/writer.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from neo4j import Auth
2323
from neo4j.auth_management import AuthManager
2424

25+
from sycamore.connectors.base_writer import BaseDBWriter
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -47,6 +48,7 @@ def opensearch(
4748
index_settings: dict,
4849
insert_settings: Optional[dict] = None,
4950
execute: bool = True,
51+
reliability_rewriter: bool = False,
5052
**kwargs,
5153
) -> Optional["DocSet"]:
5254
"""Writes the content of the DocSet into the specified OpenSearch index.
@@ -103,9 +105,9 @@ def opensearch(
103105
"""
104106

105107
from sycamore.connectors.opensearch import (
106-
OpenSearchWriter,
107108
OpenSearchWriterClientParams,
108109
OpenSearchWriterTargetParams,
110+
OpenSearchWriter,
109111
)
110112
from typing import Any
111113
import copy
@@ -137,23 +139,30 @@ def _convert_to_host_port_list(hostlist: Any) -> list[HostAndPort]:
137139
os_client_args["hosts"] = _convert_to_host_port_list(hosts)
138140
client_params = OpenSearchWriterClientParams(**os_client_args)
139141

140-
target_params: OpenSearchWriterTargetParams
141-
target_params_dict: dict[str, Any] = {"index_name": index_name}
142-
if insert_settings:
143-
target_params_dict["insert_settings"] = insert_settings
144-
if index_settings:
145-
target_params_dict["settings"] = index_settings.get("body", {}).get("settings", {})
146-
target_params_dict["mappings"] = index_settings.get("body", {}).get("mappings", {})
147-
target_params = OpenSearchWriterTargetParams(**target_params_dict)
142+
target_params = OpenSearchWriterTargetParams.from_write_args(
143+
index_name=index_name,
144+
plan=self.plan,
145+
context=self.context,
146+
reliability_rewriter=reliability_rewriter,
147+
execute=execute,
148+
insert_settings=insert_settings,
149+
index_settings=index_settings,
150+
)
148151
os = OpenSearchWriter(
149152
self.plan, client_params=client_params, target_params=target_params, name="OsrchWrite", **kwargs
150153
)
154+
client = None
155+
if reliability_rewriter:
156+
client = os.Client.from_client_params(client_params)
157+
if client._client.indices.exists(index=index_name):
158+
logger.info(f"\n\nWARNING WARNING WARNING: Deleting existing index {index_name}\n\n")
159+
client._client.indices.delete(index=index_name)
151160

152161
# We will probably want to break this at some point so that write
153162
# doesn't execute automatically, and instead you need to say something
154163
# like docset.write.opensearch().execute(), allowing sensible writes
155164
# to multiple locations and post-write operations.
156-
return self._maybe_execute(os, execute)
165+
return self._maybe_execute(os, execute, client)
157166

158167
@requires_modules(["weaviate", "weaviate.collections.classes.config"], extra="weaviate")
159168
def weaviate(
@@ -849,10 +858,16 @@ def aryn(
849858

850859
return self._maybe_execute(ds, True)
851860

852-
def _maybe_execute(self, node: Node, execute: bool) -> Optional[DocSet]:
861+
def _maybe_execute(
    self, node: Node, execute: bool, client: Optional[BaseDBWriter.Client] = None
) -> Optional[DocSet]:
    """
    Wrap ``node`` in a DocSet and optionally execute it.

    Args:
        node: The final plan node (typically a writer).
        execute: If False, the un-executed DocSet is returned and nothing runs.
        client: Optional writer client. When given and ``node`` is a
            BaseDBWriter, its ``reliability_assertor`` is called with the
            node's target params after execution.

    Returns:
        The DocSet when ``execute`` is False, otherwise None.
    """
    ds = DocSet(self.context, node)
    if not execute:
        return ds

    ds.execute()

    # isinstance rather than `type(node) == BaseDBWriter`: concrete writers are
    # subclasses of BaseDBWriter, so an exact type comparison never matches and
    # the reliability check would be silently skipped.
    if client is not None and isinstance(node, BaseDBWriter):
        client.reliability_assertor(node._target_params)
    return None

0 commit comments

Comments
 (0)