@@ -18,7 +18,7 @@ impl TryFrom<&str> for ClauseType {
1818 match value {
1919 "select" => Ok ( Self :: Select ) ,
2020 "where" => Ok ( Self :: Where ) ,
21- "from" => Ok ( Self :: From ) ,
21+ "from" | "keyword_from" => Ok ( Self :: From ) ,
2222 "update" => Ok ( Self :: Update ) ,
2323 "delete" => Ok ( Self :: Delete ) ,
2424 _ => {
@@ -88,13 +88,22 @@ impl<'a> CompletionContext<'a> {
8888
8989 let mut cursor = self . tree . as_ref ( ) . unwrap ( ) . root_node ( ) . walk ( ) ;
9090
91- // go to the statement node that matches the position
91+ /*
92+ * The head node of any treesitter tree is always the "PROGRAM" node.
93+ *
94+ * We want to enter the next layer and focus on the child node that matches the user's cursor position.
95+ * If there is no node under the users position, however, the cursor won't enter the next level – it
96+ * will stay on the Program node.
97+ *
98+ * This might lead to an unexpected context or infinite recursion.
99+ *
100+ * We'll therefore adjust the cursor position such that it meets the last node of the AST.
101+ * `select * from use {}` becomes `select * from use{}`.
102+ */
92103 let current_node_kind = cursor. node ( ) . kind ( ) ;
93-
94- dbg ! ( current_node_kind) ;
95- dbg ! ( self . position) ;
96-
97- cursor. goto_first_child_for_byte ( self . position ) ;
104+ while cursor. goto_first_child_for_byte ( self . position ) . is_none ( ) && self . position > 0 {
105+ self . position -= 1 ;
106+ }
98107
99108 self . gather_context_from_node ( cursor, current_node_kind) ;
100109 }
@@ -107,9 +116,7 @@ impl<'a> CompletionContext<'a> {
107116 let current_node = cursor. node ( ) ;
108117 let current_node_kind = current_node. kind ( ) ;
109118
110- println ! ( "inside.." ) ;
111- dbg ! ( current_node_kind) ;
112-
119+ // prevent infinite recursion – this can happen if we only have a PROGRAM node
113120 if current_node_kind == previous_node_kind {
114121 self . ts_node = Some ( current_node) ;
115122 return ;
@@ -138,9 +145,14 @@ impl<'a> CompletionContext<'a> {
138145 self . wrapping_clause_type = "where" . try_into ( ) . ok ( ) ;
139146 }
140147
148+ "keyword_from" => {
149+ self . wrapping_clause_type = "keyword_from" . try_into ( ) . ok ( ) ;
150+ }
151+
141152 _ => { }
142153 }
143154
155+ // We have arrived at the leaf node
144156 if current_node. child_count ( ) == 0 {
145157 self . ts_node = Some ( current_node) ;
146158 return ;
@@ -153,7 +165,10 @@ impl<'a> CompletionContext<'a> {
153165
154166#[ cfg( test) ]
155167mod tests {
156- use crate :: { context:: CompletionContext , test_helper:: CURSOR_POS } ;
168+ use crate :: {
169+ context:: { ClauseType , CompletionContext } ,
170+ test_helper:: { get_text_and_position, CURSOR_POS } ,
171+ } ;
157172
158173 fn get_tree ( input : & str ) -> tree_sitter:: Tree {
159174 let mut parser = tree_sitter:: Parser :: new ( ) ;
@@ -193,11 +208,11 @@ mod tests {
193208 ) ,
194209 ] ;
195210
196- for ( text, expected_clause) in test_cases {
197- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
198- let text = text. replace ( CURSOR_POS , "" ) ;
211+ for ( query, expected_clause) in test_cases {
212+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
199213
200214 let tree = get_tree ( text. as_str ( ) ) ;
215+
201216 let params = crate :: CompletionParams {
202217 position : ( position as u32 ) . into ( ) ,
203218 text : text,
@@ -226,9 +241,8 @@ mod tests {
226241 ( format!( "Select * from u{}sers()" , CURSOR_POS ) , None ) ,
227242 ] ;
228243
229- for ( text, expected_schema) in test_cases {
230- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
231- let text = text. replace ( CURSOR_POS , "" ) ;
244+ for ( query, expected_schema) in test_cases {
245+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
232246
233247 let tree = get_tree ( text. as_str ( ) ) ;
234248 let params = crate :: CompletionParams {
@@ -261,9 +275,8 @@ mod tests {
261275 ) ,
262276 ] ;
263277
264- for ( text, is_invocation) in test_cases {
265- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
266- let text = text. replace ( CURSOR_POS , "" ) ;
278+ for ( query, is_invocation) in test_cases {
279+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
267280
268281 let tree = get_tree ( text. as_str ( ) ) ;
269282 let params = crate :: CompletionParams {
@@ -280,11 +293,42 @@ mod tests {
280293 }
281294
282295 #[ test]
283- fn get_ts_node_content_does_not_fail_on_error_nodes ( ) {
284- let query = format ! ( "select * from {}" , CURSOR_POS ) ;
296+ fn does_not_fail_on_leading_whitespace ( ) {
297+ let cases = vec ! [
298+ format!( "{} select * from" , CURSOR_POS ) ,
299+ format!( " {} select * from" , CURSOR_POS ) ,
300+ ] ;
301+
302+ for query in cases {
303+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
285304
286- let position = query. find ( CURSOR_POS ) . unwrap ( ) ;
287- let text = query. replace ( CURSOR_POS , "" ) ;
305+ let tree = get_tree ( text. as_str ( ) ) ;
306+
307+ let params = crate :: CompletionParams {
308+ position : ( position as u32 ) . into ( ) ,
309+ text : text,
310+ tree : Some ( & tree) ,
311+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
312+ } ;
313+
314+ let ctx = CompletionContext :: new ( & params) ;
315+
316+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
317+
318+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "select" ) ) ;
319+
320+ assert_eq ! (
321+ ctx. wrapping_clause_type,
322+ Some ( crate :: context:: ClauseType :: Select )
323+ ) ;
324+ }
325+ }
326+
327+ #[ test]
328+ fn does_not_fail_on_trailing_whitespace ( ) {
329+ let query = format ! ( "select * from {}" , CURSOR_POS ) ;
330+
331+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
288332
289333 let tree = get_tree ( text. as_str ( ) ) ;
290334
@@ -297,10 +341,60 @@ mod tests {
297341
298342 let ctx = CompletionContext :: new ( & params) ;
299343
300- let node = ctx. ts_node . map ( |n| n. clone ( ) ) ;
344+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
345+
346+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "from" ) ) ;
347+ assert_eq ! (
348+ ctx. wrapping_clause_type,
349+ Some ( crate :: context:: ClauseType :: From )
350+ ) ;
351+ }
352+
353+ #[ test]
354+ fn does_not_fail_with_empty_statements ( ) {
355+ let query = format ! ( "{}" , CURSOR_POS ) ;
356+
357+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
358+
359+ let tree = get_tree ( text. as_str ( ) ) ;
360+
361+ let params = crate :: CompletionParams {
362+ position : ( position as u32 ) . into ( ) ,
363+ text : text,
364+ tree : Some ( & tree) ,
365+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
366+ } ;
367+
368+ let ctx = CompletionContext :: new ( & params) ;
369+
370+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
371+
372+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "" ) ) ;
373+ assert_eq ! ( ctx. wrapping_clause_type, None ) ;
374+ }
375+
376+ #[ test]
377+ fn does_not_fail_on_incomplete_keywords ( ) {
378+ // Instead of autocompleting "FROM", we'll assume that the user
379+ // is selecting a certain column name, such as `frozen_account`.
380+ let query = format ! ( "select * fro{}" , CURSOR_POS ) ;
381+
382+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
383+
384+ let tree = get_tree ( text. as_str ( ) ) ;
385+
386+ let params = crate :: CompletionParams {
387+ position : ( position as u32 ) . into ( ) ,
388+ text : text,
389+ tree : Some ( & tree) ,
390+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
391+ } ;
392+
393+ let ctx = CompletionContext :: new ( & params) ;
301394
302- println ! ( " node kind: {}" , node . as_ref ( ) . unwrap( ) . kind ( ) ) ;
395+ let node = ctx . ts_node . map ( |n| n . clone ( ) ) . unwrap ( ) ;
303396
304- ctx. get_ts_node_content ( node. unwrap ( ) ) ;
397+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "fro" ) ) ;
398+ assert_eq ! ( ctx. wrapping_clause_type, Some ( ClauseType :: Select ) ) ;
305399 }
306400}
0 commit comments