@@ -18,7 +18,7 @@ impl TryFrom<&str> for ClauseType {
1818 match value {
1919 "select" => Ok ( Self :: Select ) ,
2020 "where" => Ok ( Self :: Where ) ,
21- "from" => Ok ( Self :: From ) ,
21+ "from" | "keyword_from" => Ok ( Self :: From ) ,
2222 "update" => Ok ( Self :: Update ) ,
2323 "delete" => Ok ( Self :: Delete ) ,
2424 _ => {
@@ -88,10 +88,22 @@ impl<'a> CompletionContext<'a> {
8888
8989 let mut cursor = self . tree . as_ref ( ) . unwrap ( ) . root_node ( ) . walk ( ) ;
9090
91- // go to the statement node that matches the position
91+ /*
92+ * The head node of any treesitter tree is always the "PROGRAM" node.
93+ *
94+ * We want to enter the next layer and focus on the child node that matches the user's cursor position.
95+ * If there is no node under the users position, however, the cursor won't enter the next level – it
96+ * will stay on the Program node.
97+ *
98+ * This might lead to an unexpected context or infinite recursion.
99+ *
100+ * We'll therefore adjust the cursor position such that it meets the last node of the AST.
101+ * `select * from use {}` becomes `select * from use{}`.
102+ */
92103 let current_node_kind = cursor. node ( ) . kind ( ) ;
93-
94- cursor. goto_first_child_for_byte ( self . position ) ;
104+ while cursor. goto_first_child_for_byte ( self . position ) . is_none ( ) && self . position > 0 {
105+ self . position -= 1 ;
106+ }
95107
96108 self . gather_context_from_node ( cursor, current_node_kind) ;
97109 }
@@ -104,6 +116,12 @@ impl<'a> CompletionContext<'a> {
104116 let current_node = cursor. node ( ) ;
105117 let current_node_kind = current_node. kind ( ) ;
106118
119+ // prevent infinite recursion – this can happen if we only have a PROGRAM node
120+ if current_node_kind == previous_node_kind {
121+ self . ts_node = Some ( current_node) ;
122+ return ;
123+ }
124+
107125 match previous_node_kind {
108126 "statement" => self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ,
109127 "invocation" => self . is_invocation = true ,
@@ -127,9 +145,14 @@ impl<'a> CompletionContext<'a> {
127145 self . wrapping_clause_type = "where" . try_into ( ) . ok ( ) ;
128146 }
129147
148+ "keyword_from" => {
149+ self . wrapping_clause_type = "keyword_from" . try_into ( ) . ok ( ) ;
150+ }
151+
130152 _ => { }
131153 }
132154
155+ // We have arrived at the leaf node
133156 if current_node. child_count ( ) == 0 {
134157 self . ts_node = Some ( current_node) ;
135158 return ;
@@ -142,7 +165,10 @@ impl<'a> CompletionContext<'a> {
142165
143166#[ cfg( test) ]
144167mod tests {
145- use crate :: { context:: CompletionContext , test_helper:: CURSOR_POS } ;
168+ use crate :: {
169+ context:: { ClauseType , CompletionContext } ,
170+ test_helper:: { get_text_and_position, CURSOR_POS } ,
171+ } ;
146172
147173 fn get_tree ( input : & str ) -> tree_sitter:: Tree {
148174 let mut parser = tree_sitter:: Parser :: new ( ) ;
@@ -182,11 +208,11 @@ mod tests {
182208 ) ,
183209 ] ;
184210
185- for ( text, expected_clause) in test_cases {
186- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
187- let text = text. replace ( CURSOR_POS , "" ) ;
211+ for ( query, expected_clause) in test_cases {
212+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
188213
189214 let tree = get_tree ( text. as_str ( ) ) ;
215+
190216 let params = crate :: CompletionParams {
191217 position : ( position as u32 ) . into ( ) ,
192218 text : text,
@@ -215,9 +241,8 @@ mod tests {
215241 ( format!( "Select * from u{}sers()" , CURSOR_POS ) , None ) ,
216242 ] ;
217243
218- for ( text, expected_schema) in test_cases {
219- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
220- let text = text. replace ( CURSOR_POS , "" ) ;
244+ for ( query, expected_schema) in test_cases {
245+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
221246
222247 let tree = get_tree ( text. as_str ( ) ) ;
223248 let params = crate :: CompletionParams {
@@ -250,9 +275,8 @@ mod tests {
250275 ) ,
251276 ] ;
252277
253- for ( text, is_invocation) in test_cases {
254- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
255- let text = text. replace ( CURSOR_POS , "" ) ;
278+ for ( query, is_invocation) in test_cases {
279+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
256280
257281 let tree = get_tree ( text. as_str ( ) ) ;
258282 let params = crate :: CompletionParams {
@@ -267,4 +291,110 @@ mod tests {
267291 assert_eq ! ( ctx. is_invocation, is_invocation) ;
268292 }
269293 }
294+
295+ #[ test]
296+ fn does_not_fail_on_leading_whitespace ( ) {
297+ let cases = vec ! [
298+ format!( "{} select * from" , CURSOR_POS ) ,
299+ format!( " {} select * from" , CURSOR_POS ) ,
300+ ] ;
301+
302+ for query in cases {
303+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
304+
305+ let tree = get_tree ( text. as_str ( ) ) ;
306+
307+ let params = crate :: CompletionParams {
308+ position : ( position as u32 ) . into ( ) ,
309+ text : text,
310+ tree : Some ( & tree) ,
311+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
312+ } ;
313+
314+ let ctx = CompletionContext :: new ( & params) ;
315+
316+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
317+
318+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "select" ) ) ;
319+
320+ assert_eq ! (
321+ ctx. wrapping_clause_type,
322+ Some ( crate :: context:: ClauseType :: Select )
323+ ) ;
324+ }
325+ }
326+
327+ #[ test]
328+ fn does_not_fail_on_trailing_whitespace ( ) {
329+ let query = format ! ( "select * from {}" , CURSOR_POS ) ;
330+
331+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
332+
333+ let tree = get_tree ( text. as_str ( ) ) ;
334+
335+ let params = crate :: CompletionParams {
336+ position : ( position as u32 ) . into ( ) ,
337+ text : text,
338+ tree : Some ( & tree) ,
339+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
340+ } ;
341+
342+ let ctx = CompletionContext :: new ( & params) ;
343+
344+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
345+
346+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "from" ) ) ;
347+ assert_eq ! (
348+ ctx. wrapping_clause_type,
349+ Some ( crate :: context:: ClauseType :: From )
350+ ) ;
351+ }
352+
353+ #[ test]
354+ fn does_not_fail_with_empty_statements ( ) {
355+ let query = format ! ( "{}" , CURSOR_POS ) ;
356+
357+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
358+
359+ let tree = get_tree ( text. as_str ( ) ) ;
360+
361+ let params = crate :: CompletionParams {
362+ position : ( position as u32 ) . into ( ) ,
363+ text : text,
364+ tree : Some ( & tree) ,
365+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
366+ } ;
367+
368+ let ctx = CompletionContext :: new ( & params) ;
369+
370+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
371+
372+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "" ) ) ;
373+ assert_eq ! ( ctx. wrapping_clause_type, None ) ;
374+ }
375+
376+ #[ test]
377+ fn does_not_fail_on_incomplete_keywords ( ) {
378+ // Instead of autocompleting "FROM", we'll assume that the user
379+ // is selecting a certain column name, such as `frozen_account`.
380+ let query = format ! ( "select * fro{}" , CURSOR_POS ) ;
381+
382+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
383+
384+ let tree = get_tree ( text. as_str ( ) ) ;
385+
386+ let params = crate :: CompletionParams {
387+ position : ( position as u32 ) . into ( ) ,
388+ text : text,
389+ tree : Some ( & tree) ,
390+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
391+ } ;
392+
393+ let ctx = CompletionContext :: new ( & params) ;
394+
395+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
396+
397+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "fro" ) ) ;
398+ assert_eq ! ( ctx. wrapping_clause_type, Some ( ClauseType :: Select ) ) ;
399+ }
270400}
0 commit comments