1+ use std:: ops:: Range ;
2+
13use pg_schema_cache:: SchemaCache ;
4+ use pg_treesitter_queries:: { queries, TreeSitterQueriesExecutor } ;
25
36use crate :: CompletionParams ;
47
@@ -52,10 +55,13 @@ pub(crate) struct CompletionContext<'a> {
5255 pub schema_name : Option < String > ,
5356 pub wrapping_clause_type : Option < ClauseType > ,
5457 pub is_invocation : bool ,
58+ pub wrapping_statement_range : Option < Range < usize > > ,
59+
60+ pub ts_query_executor : Option < TreeSitterQueriesExecutor < ' a > > ,
5561}
5662
5763impl < ' a > CompletionContext < ' a > {
58- pub fn new ( params : & ' a CompletionParams ) -> Self {
64+ pub async fn new ( params : & ' a CompletionParams < ' a > ) -> Self {
5965 let mut ctx = Self {
6066 tree : params. tree ,
6167 text : & params. text ,
@@ -65,14 +71,30 @@ impl<'a> CompletionContext<'a> {
6571 ts_node : None ,
6672 schema_name : None ,
6773 wrapping_clause_type : None ,
74+ wrapping_statement_range : None ,
6875 is_invocation : false ,
76+ ts_query_executor : None ,
6977 } ;
7078
7179 ctx. gather_tree_context ( ) ;
80+ ctx. dispatch_ts_queries ( ) . await ;
7281
7382 ctx
7483 }
7584
85+ async fn dispatch_ts_queries ( & mut self ) {
86+ let tree = match self . tree . as_ref ( ) {
87+ None => return ,
88+ Some ( t) => t,
89+ } ;
90+
91+ let mut executor = TreeSitterQueriesExecutor :: new ( tree. root_node ( ) , self . text ) ;
92+
93+ executor. add_query_results :: < queries:: RelationMatch > ( ) . await ;
94+
95+ self . ts_query_executor = Some ( executor) ;
96+ }
97+
7698 pub fn get_ts_node_content ( & self , ts_node : tree_sitter:: Node < ' a > ) -> Option < & ' a str > {
7799 let source = self . text ;
78100 match ts_node. utf8_text ( source. as_bytes ( ) ) {
@@ -100,36 +122,38 @@ impl<'a> CompletionContext<'a> {
100122 * We'll therefore adjust the cursor position such that it meets the last node of the AST.
101123 * `select * from use {}` becomes `select * from use{}`.
102124 */
103- let current_node_kind = cursor. node ( ) . kind ( ) ;
125+ let current_node = cursor. node ( ) ;
104126 while cursor. goto_first_child_for_byte ( self . position ) . is_none ( ) && self . position > 0 {
105127 self . position -= 1 ;
106128 }
107129
108- self . gather_context_from_node ( cursor, current_node_kind ) ;
130+ self . gather_context_from_node ( cursor, current_node ) ;
109131 }
110132
111133 fn gather_context_from_node (
112134 & mut self ,
113135 mut cursor : tree_sitter:: TreeCursor < ' a > ,
114- previous_node_kind : & str ,
136+ previous_node : tree_sitter :: Node < ' a > ,
115137 ) {
116138 let current_node = cursor. node ( ) ;
117- let current_node_kind = current_node. kind ( ) ;
118139
119140 // prevent infinite recursion – this can happen if we only have a PROGRAM node
120- if current_node_kind == previous_node_kind {
141+ if current_node . kind ( ) == previous_node . kind ( ) {
121142 self . ts_node = Some ( current_node) ;
122143 return ;
123144 }
124145
125- match previous_node_kind {
126- "statement" => self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ,
146+ match previous_node. kind ( ) {
147+ "statement" => {
148+ self . wrapping_clause_type = current_node. kind ( ) . try_into ( ) . ok ( ) ;
149+ self . wrapping_statement_range = Some ( previous_node. byte_range ( ) ) ;
150+ }
127151 "invocation" => self . is_invocation = true ,
128152
129153 _ => { }
130154 }
131155
132- match current_node_kind {
156+ match current_node . kind ( ) {
133157 "object_reference" => {
134158 let txt = self . get_ts_node_content ( current_node) ;
135159 if let Some ( txt) = txt {
@@ -159,7 +183,7 @@ impl<'a> CompletionContext<'a> {
159183 }
160184
161185 cursor. goto_first_child_for_byte ( self . position ) ;
162- self . gather_context_from_node ( cursor, current_node_kind ) ;
186+ self . gather_context_from_node ( cursor, current_node ) ;
163187 }
164188}
165189
@@ -179,8 +203,8 @@ mod tests {
179203 parser. parse ( input, None ) . expect ( "Unable to parse tree" )
180204 }
181205
182- #[ test]
183- fn identifies_clauses ( ) {
206+ #[ tokio :: test]
207+ async fn identifies_clauses ( ) {
184208 let test_cases = vec ! [
185209 ( format!( "Select {}* from users;" , CURSOR_POS ) , "select" ) ,
186210 ( format!( "Select * from u{};" , CURSOR_POS ) , "from" ) ,
@@ -220,14 +244,14 @@ mod tests {
220244 schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
221245 } ;
222246
223- let ctx = CompletionContext :: new ( & params) ;
247+ let ctx = CompletionContext :: new ( & params) . await ;
224248
225249 assert_eq ! ( ctx. wrapping_clause_type, expected_clause. try_into( ) . ok( ) ) ;
226250 }
227251 }
228252
229- #[ test]
230- fn identifies_schema ( ) {
253+ #[ tokio :: test]
254+ async fn identifies_schema ( ) {
231255 let test_cases = vec ! [
232256 (
233257 format!( "Select * from private.u{}" , CURSOR_POS ) ,
@@ -252,14 +276,14 @@ mod tests {
252276 schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
253277 } ;
254278
255- let ctx = CompletionContext :: new ( & params) ;
279+ let ctx = CompletionContext :: new ( & params) . await ;
256280
257281 assert_eq ! ( ctx. schema_name, expected_schema. map( |f| f. to_string( ) ) ) ;
258282 }
259283 }
260284
261- #[ test]
262- fn identifies_invocation ( ) {
285+ #[ tokio :: test]
286+ async fn identifies_invocation ( ) {
263287 let test_cases = vec ! [
264288 ( format!( "Select * from u{}sers" , CURSOR_POS ) , false ) ,
265289 ( format!( "Select * from u{}sers()" , CURSOR_POS ) , true ) ,
@@ -286,14 +310,14 @@ mod tests {
286310 schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
287311 } ;
288312
289- let ctx = CompletionContext :: new ( & params) ;
313+ let ctx = CompletionContext :: new ( & params) . await ;
290314
291315 assert_eq ! ( ctx. is_invocation, is_invocation) ;
292316 }
293317 }
294318
295- #[ test]
296- fn does_not_fail_on_leading_whitespace ( ) {
319+ #[ tokio :: test]
320+ async fn does_not_fail_on_leading_whitespace ( ) {
297321 let cases = vec ! [
298322 format!( "{} select * from" , CURSOR_POS ) ,
299323 format!( " {} select * from" , CURSOR_POS ) ,
@@ -311,7 +335,7 @@ mod tests {
311335 schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
312336 } ;
313337
314- let ctx = CompletionContext :: new ( & params) ;
338+ let ctx = CompletionContext :: new ( & params) . await ;
315339
316340 let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
317341
@@ -324,8 +348,8 @@ mod tests {
324348 }
325349 }
326350
327- #[ test]
328- fn does_not_fail_on_trailing_whitespace ( ) {
351+ #[ tokio :: test]
352+ async fn does_not_fail_on_trailing_whitespace ( ) {
329353 let query = format ! ( "select * from {}" , CURSOR_POS ) ;
330354
331355 let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
@@ -339,7 +363,7 @@ mod tests {
339363 schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
340364 } ;
341365
342- let ctx = CompletionContext :: new ( & params) ;
366+ let ctx = CompletionContext :: new ( & params) . await ;
343367
344368 let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
345369
@@ -350,8 +374,8 @@ mod tests {
350374 ) ;
351375 }
352376
353- #[ test]
354- fn does_not_fail_with_empty_statements ( ) {
377+ #[ tokio :: test]
378+ async fn does_not_fail_with_empty_statements ( ) {
355379 let query = format ! ( "{}" , CURSOR_POS ) ;
356380
357381 let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
@@ -365,16 +389,16 @@ mod tests {
365389 schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
366390 } ;
367391
368- let ctx = CompletionContext :: new ( & params) ;
392+ let ctx = CompletionContext :: new ( & params) . await ;
369393
370394 let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
371395
372396 assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "" ) ) ;
373397 assert_eq ! ( ctx. wrapping_clause_type, None ) ;
374398 }
375399
376- #[ test]
377- fn does_not_fail_on_incomplete_keywords ( ) {
400+ #[ tokio :: test]
401+ async fn does_not_fail_on_incomplete_keywords ( ) {
378402 // Instead of autocompleting "FROM", we'll assume that the user
379403 // is selecting a certain column name, such as `frozen_account`.
380404 let query = format ! ( "select * fro{}" , CURSOR_POS ) ;
@@ -390,7 +414,7 @@ mod tests {
390414 schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
391415 } ;
392416
393- let ctx = CompletionContext :: new ( & params) ;
417+ let ctx = CompletionContext :: new ( & params) . await ;
394418
395419 let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
396420
0 commit comments