77
88import static java .util .concurrent .TimeUnit .SECONDS ;
99import static org .opensearch .core .xcontent .XContentParserUtils .ensureExpectedToken ;
10- import static org .opensearch .ml .common .CommonValue .stopWordsIndices ;
1110import static org .opensearch .ml .common .utils .StringUtils .gson ;
1211
1312import java .io .IOException ;
2524import java .util .stream .Collectors ;
2625
2726import org .opensearch .action .LatchedActionListener ;
28- import org .opensearch .action .search .SearchRequest ;
2927import org .opensearch .action .search .SearchResponse ;
3028import org .opensearch .common .util .concurrent .ThreadContext ;
3129import org .opensearch .common .xcontent .LoggingDeprecationHandler ;
3634import org .opensearch .core .xcontent .NamedXContentRegistry ;
3735import org .opensearch .core .xcontent .XContentBuilder ;
3836import org .opensearch .core .xcontent .XContentParser ;
37+ import org .opensearch .remote .metadata .client .SdkClient ;
38+ import org .opensearch .remote .metadata .client .SearchDataObjectRequest ;
39+ import org .opensearch .remote .metadata .common .SdkClientUtils ;
3940import org .opensearch .search .builder .SearchSourceBuilder ;
4041import org .opensearch .transport .client .Client ;
4142
@@ -58,6 +59,8 @@ public class LocalRegexGuardrail extends Guardrail {
5859 private Map <String , List <String >> stopWordsIndicesInput ;
5960 private NamedXContentRegistry xContentRegistry ;
6061 private Client client ;
62+ private SdkClient sdkClient ;
63+ private String tenantId ;
6164
6265 @ Builder (toBuilder = true )
6366 public LocalRegexGuardrail (List <StopWords > stopWords , String [] regex ) {
@@ -109,9 +112,11 @@ public Boolean validate(String input, Map<String, String> parameters) {
109112 }
110113
111114 @ Override
112- public void init (NamedXContentRegistry xContentRegistry , Client client ) {
115+ public void init (NamedXContentRegistry xContentRegistry , Client client , SdkClient sdkClient , String tenantId ) {
113116 this .xContentRegistry = xContentRegistry ;
114117 this .client = client ;
118+ this .sdkClient = sdkClient ;
119+ this .tenantId = tenantId ;
115120 init ();
116121 }
117122
@@ -211,55 +216,34 @@ public Boolean validateStopWords(String input, Map<String, List<String>> stopWor
211216 * @return true if no stop words matching, otherwise false.
212217 */
213218 public Boolean validateStopWordsSingleIndex (String input , String indexName , List <String > fieldNames ) {
214- SearchRequest searchRequest ;
215- AtomicBoolean hitStopWords = new AtomicBoolean (false );
219+ AtomicBoolean passedStopWordCheck = new AtomicBoolean (false );
216220 String queryBody ;
217221 Map <String , String > documentMap = new HashMap <>();
218222 for (String field : fieldNames ) {
219223 documentMap .put (field , input );
220224 }
221225 Map <String , Object > queryBodyMap = Map .of ("query" , Map .of ("percolate" , Map .of ("field" , "query" , "document" , documentMap )));
222226 CountDownLatch latch = new CountDownLatch (1 );
223- ThreadContext .StoredContext context = null ;
224-
225227 try {
226228 queryBody = AccessController .doPrivileged ((PrivilegedExceptionAction <String >) () -> gson .toJson (queryBodyMap ));
227- SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ();
228- XContentParser queryParser = XContentType .JSON
229- .xContent ()
230- .createParser (xContentRegistry , LoggingDeprecationHandler .INSTANCE , queryBody );
231- searchSourceBuilder .parseXContent (queryParser );
232- searchSourceBuilder .size (1 ); // Only need 1 doc returned, if hit.
233- searchRequest = new SearchRequest ().source (searchSourceBuilder ).indices (indexName );
234- if (isStopWordsSystemIndex (indexName )) {
235- context = client .threadPool ().getThreadContext ().stashContext ();
236- ThreadContext .StoredContext finalContext = context ;
237- client .search (searchRequest , ActionListener .runBefore (new LatchedActionListener (ActionListener .<SearchResponse >wrap (r -> {
238- if (r == null || r .getHits () == null || r .getHits ().getTotalHits () == null || r .getHits ().getTotalHits ().value () == 0 ) {
239- hitStopWords .set (true );
240- }
241- }, e -> {
242- log .error ("Failed to search stop words index {}" , indexName , e );
243- hitStopWords .set (true );
244- }), latch ), () -> finalContext .restore ()));
245- } else {
246- client .search (searchRequest , new LatchedActionListener (ActionListener .<SearchResponse >wrap (r -> {
247- if (r == null || r .getHits () == null || r .getHits ().getTotalHits () == null || r .getHits ().getTotalHits ().value () == 0 ) {
248- hitStopWords .set (true );
249- }
250- }, e -> {
251- log .error ("Failed to search stop words index {}" , indexName , e );
252- hitStopWords .set (true );
253- }), latch ));
229+ SearchDataObjectRequest searchDataObjectRequest = buildSearchDataObjectRequest (indexName , queryBody );
230+ var responseListener = new LatchedActionListener <>(ActionListener .<SearchResponse >wrap (r -> {
231+ if (r == null || r .getHits () == null || r .getHits ().getTotalHits () == null || r .getHits ().getTotalHits ().value () == 0 ) {
232+ passedStopWordCheck .set (true );
233+ }
234+ }, e -> {
235+ log .error ("Failed to search stop words index {}" , indexName , e );
236+ passedStopWordCheck .set (true );
237+ }), latch );
238+ try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
239+ sdkClient
240+ .searchDataObjectAsync (searchDataObjectRequest )
241+ .whenComplete (SdkClientUtils .wrapSearchCompletion (ActionListener .runBefore (responseListener , context ::restore )));
254242 }
255243 } catch (Exception e ) {
256244 log .error ("[validateStopWords] Searching stop words index failed." , e );
257245 latch .countDown ();
258- hitStopWords .set (true );
259- } finally {
260- if (context != null ) {
261- context .close ();
262- }
246+ passedStopWordCheck .set (true );
263247 }
264248
265249 try {
@@ -268,10 +252,17 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
268252 log .error ("[validateStopWords] Searching stop words index was timeout." , e );
269253 throw new IllegalStateException (e );
270254 }
271- return hitStopWords .get ();
255+ return passedStopWordCheck .get ();
272256 }
273257
274- private boolean isStopWordsSystemIndex (String index ) {
275- return stopWordsIndices .contains (index );
258+ protected SearchDataObjectRequest buildSearchDataObjectRequest (String indexName , String queryBody ) throws IOException {
259+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ();
260+ XContentParser queryParser = XContentType .JSON
261+ .xContent ()
262+ .createParser (xContentRegistry , LoggingDeprecationHandler .INSTANCE , queryBody );
263+ searchSourceBuilder .parseXContent (queryParser );
264+ searchSourceBuilder .size (1 ); // Only need 1 doc returned, if hit.
265+
266+ return SearchDataObjectRequest .builder ().indices (indexName ).searchSourceBuilder (searchSourceBuilder ).tenantId (tenantId ).build ();
276267 }
277268}
0 commit comments