16
16
DEFAULT_RECORD_PROPERTIES ,
17
17
)
18
18
from sycamore .utils .import_utils import requires_modules
19
+ from sycamore .plan_nodes import Node
20
+ from sycamore .docset import DocSet
21
+ from sycamore .context import Context
19
22
20
23
if typing .TYPE_CHECKING :
21
24
from opensearchpy import OpenSearch
@@ -42,6 +45,7 @@ class OpenSearchWriterClientParams(BaseDBWriter.ClientParams):
42
45
@dataclass
43
46
class OpenSearchWriterTargetParams (BaseDBWriter .TargetParams ):
44
47
index_name : str
48
+ _doc_count : int = 0
45
49
settings : dict [str , Any ] = field (default_factory = lambda : {"index.knn" : True })
46
50
mappings : dict [str , Any ] = field (
47
51
default_factory = lambda : {
@@ -92,6 +96,61 @@ def compatible_with(self, other: "BaseDBWriter.TargetParams") -> bool:
92
96
other_flat_mappings = dict (flatten_data (other .mappings ))
93
97
return check_dictionary_compatibility (my_flat_mappings , other_flat_mappings )
94
98
99
+ @classmethod
100
+ def from_write_args (
101
+ cls ,
102
+ index_name : str ,
103
+ plan : Node ,
104
+ context : Context ,
105
+ reliability_rewriter : bool ,
106
+ execute : bool ,
107
+ insert_settings : Optional [dict ] = None ,
108
+ index_settings : Optional [dict ] = None ,
109
+ ) -> "OpenSearchWriterTargetParams" :
110
+ """
111
+ Build OpenSearchWriterTargetParams from write operation arguments.
112
+
113
+ Args:
114
+ index_name: Name of the OpenSearch index
115
+ plan: The execution plan Node
116
+ context: The execution Context
117
+ reliability_rewriter: Whether to enable reliability rewriter mode
118
+ execute: Whether to execute the pipeline immediately
119
+ insert_settings: Optional settings for data insertion
120
+ index_settings: Optional index configuration settings
121
+
122
+ Returns:
123
+ OpenSearchWriterTargetParams configured with the provided settings
124
+
125
+ Raises:
126
+ AssertionError: If reliability_rewriter conditions are not met
127
+ """
128
+ target_params_dict : dict [str , Any ] = {
129
+ "index_name" : index_name ,
130
+ "_doc_count" : 0 ,
131
+ }
132
+
133
+ if reliability_rewriter :
134
+ from sycamore .materialize import Materialize
135
+
136
+ assert execute , "Reliability rewriter requires execute to be True"
137
+ assert (
138
+ type (plan ) == Materialize
139
+ ), "The first node must be a materialize node for reliability rewriter to work"
140
+ assert not plan .children [
141
+ 0
142
+ ], "Pipeline should only have read materialize and write nodes for reliability rewriter to work"
143
+ target_params_dict ["_doc_count" ] = DocSet (context , plan ).count ()
144
+
145
+ if insert_settings :
146
+ target_params_dict ["insert_settings" ] = insert_settings
147
+
148
+ if index_settings :
149
+ target_params_dict ["settings" ] = index_settings .get ("body" , {}).get ("settings" , {})
150
+ target_params_dict ["mappings" ] = index_settings .get ("body" , {}).get ("mappings" , {})
151
+
152
+ return cls (** target_params_dict )
153
+
95
154
96
155
class OpenSearchWriterClient (BaseDBWriter .Client ):
97
156
def __init__ (self , os_client : "OpenSearch" ):
@@ -187,6 +246,8 @@ def _string_values_to_python_types(obj: Any):
187
246
return obj
188
247
return obj
189
248
249
+ # TODO: Convert OpenSearchWriterTargetParams to pydantic model
250
+
190
251
assert isinstance (
191
252
target_params , OpenSearchWriterTargetParams
192
253
), f"Provided target_params was not of type OpenSearchWriterTargetParams:\n { target_params } "
@@ -196,7 +257,27 @@ def _string_values_to_python_types(obj: Any):
196
257
assert isinstance (mappings , dict )
197
258
settings = _string_values_to_python_types (response .get (index_name , {}).get ("settings" , {}))
198
259
assert isinstance (settings , dict )
199
- return OpenSearchWriterTargetParams (index_name = index_name , mappings = mappings , settings = settings )
260
+ _doc_count = target_params ._doc_count
261
+ assert isinstance (_doc_count , int )
262
+ return OpenSearchWriterTargetParams (
263
+ index_name = index_name ,
264
+ mappings = mappings ,
265
+ settings = settings ,
266
+ _doc_count = _doc_count ,
267
+ )
268
+
269
+ def reliability_assertor (self , target_params : BaseDBWriter .TargetParams ):
270
+ assert isinstance (
271
+ target_params , OpenSearchWriterTargetParams
272
+ ), f"Provided target_params was not of type OpenSearchWriterTargetParams:\n { target_params } "
273
+ log .info ("Flushing index..." )
274
+ self ._client .indices .flush (index = target_params .index_name , params = {"timeout" : 300 })
275
+ log .info ("Done flushing index." )
276
+ indices = self ._client .cat .indices (index = target_params .index_name , format = "json" )
277
+ assert len (indices ) == 1 , f"Expected 1 index, found { len (indices )} "
278
+ num_docs = int (indices [0 ]["docs.count" ])
279
+ log .info (f"{ num_docs } chunks written in index { target_params .index_name } " )
280
+ assert num_docs == target_params ._doc_count , f"Expected { target_params ._doc_count } docs, found { num_docs } "
200
281
201
282
202
283
@dataclass
0 commit comments