Skip to content

Commit 3e4a92e

Browse files
committed
zz
1 parent 80799d6 commit 3e4a92e

File tree

2 files changed

+312
-2
lines changed

2 files changed

+312
-2
lines changed

gopls/internal/golang/highlight.go

+277-2
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ import (
1010
"go/ast"
1111
"go/token"
1212
"go/types"
13+
"io"
14+
"strings"
1315

1416
"golang.org/x/tools/go/ast/astutil"
1517
"golang.org/x/tools/gopls/internal/cache"
1618
"golang.org/x/tools/gopls/internal/file"
1719
"golang.org/x/tools/gopls/internal/protocol"
20+
gastutil "golang.org/x/tools/gopls/internal/util/astutil"
1821
"golang.org/x/tools/internal/event"
1922
)
2023

@@ -49,7 +52,7 @@ func Highlight(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, po
4952
}
5053
}
5154
}
52-
result, err := highlightPath(path, pgf.File, pkg.TypesInfo())
55+
result, err := highlightPath(path, pgf.File, pkg.TypesInfo(), pos)
5356
if err != nil {
5457
return nil, err
5558
}
@@ -69,8 +72,20 @@ func Highlight(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, po
6972

7073
// highlightPath returns ranges to highlight for the given enclosing path,
7174
// which should be the result of astutil.PathEnclosingInterval.
72-
func highlightPath(path []ast.Node, file *ast.File, info *types.Info) (map[posRange]protocol.DocumentHighlightKind, error) {
75+
func highlightPath(path []ast.Node, file *ast.File, info *types.Info, pos token.Pos) (map[posRange]protocol.DocumentHighlightKind, error) {
7376
result := make(map[posRange]protocol.DocumentHighlightKind)
77+
// Inside a printf-style call?
78+
for _, node := range path {
79+
if call, ok := node.(*ast.CallExpr); ok {
80+
for _, args := range call.Args {
81+
// Only try when pos is in right side of the format String.
82+
if basicList, ok := args.(*ast.BasicLit); ok && basicList.Pos() < pos &&
83+
basicList.Kind == token.STRING && strings.Contains(basicList.Value, "%") {
84+
highlightPrintf(basicList, call, pos, result)
85+
}
86+
}
87+
}
88+
}
7489
switch node := path[0].(type) {
7590
case *ast.BasicLit:
7691
// Import path string literal?
@@ -131,6 +146,266 @@ func highlightPath(path []ast.Node, file *ast.File, info *types.Info) (map[posRa
131146
return result, nil
132147
}
133148

149+
// highlightPrintf identifies and highlights the relationships between placeholders
150+
// in a format string and their corresponding variadic arguments in a printf-style
151+
// function call.
152+
//
153+
// For example:
154+
//
155+
// fmt.Printf("Hello %s, you scored %d", name, score)
156+
//
157+
// If the cursor is on %s or name, highlightPrintf will highlight %s as a write operation,
158+
// and name as a read operation.
159+
func highlightPrintf(directive *ast.BasicLit, call *ast.CallExpr, pos token.Pos, result map[posRange]protocol.DocumentHighlightKind) {
160+
format := directive.Value
161+
// Give up when encounter '% %', '%%' for simplicity.
162+
// For example:
163+
//
164+
// fmt.Printf("hello % %s, %-2.3d\n", "world", 123)
165+
//
166+
// The implementation of fmt.doPrintf will ignore first two '%'s,
167+
// causing arguments count bigger than placeholders count (2 > 1), producing
168+
// "%!(EXTRA" error string in formatFunc and incorrect highlight range.
169+
//
170+
// fmt.Printf("%% %s, %-2.3d\n", "world", 123)
171+
//
172+
// This case it will not emit errors, but the recording range of parsef is going to
173+
// shift left because two % are interpreted as one %(escaped), so it becomes:
174+
// fmt.Printf("%% %s, %-2.3d\n", "world", 123)
175+
// | | the range will include a whitespace in left of %s
176+
for i := range len(format) {
177+
if format[i] == '%' {
178+
for j := i + 1; j < len(format); j++ {
179+
c := format[j]
180+
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') {
181+
break
182+
}
183+
if c == '%' {
184+
return
185+
}
186+
}
187+
}
188+
}
189+
190+
// Computation is based on count of '%', when placeholders and variadic arguments missmatch,
191+
// users are most likely completing arguments, so try to highlight any unfinished one.
192+
// Make sure variadic arguments passed to parsef matches correct count of '%'.
193+
expectedVariadicArgs := make([]ast.Expr, strings.Count(format, "%"))
194+
firstVariadic := -1
195+
for i, arg := range call.Args {
196+
if directive == arg {
197+
firstVariadic = i + 1
198+
argsLen := len(call.Args) - i - 1
199+
if argsLen > len(expectedVariadicArgs) {
200+
// Translate from Printf(a0,"%d %d",5, 6, 7) to [5, 6]
201+
copy(expectedVariadicArgs, call.Args[firstVariadic:firstVariadic+len(expectedVariadicArgs)])
202+
} else {
203+
// Translate from Printf(a0,"%d %d %s",5, 6) to [5, 6, nil]
204+
copy(expectedVariadicArgs[:argsLen], call.Args[firstVariadic:])
205+
}
206+
break
207+
}
208+
}
209+
var percent formatPercent
210+
// Get a position-ordered slice describing each directive item.
211+
parsedDirectives := parsef(format, directive.Pos(), expectedVariadicArgs...)
212+
// Cursor in argument.
213+
if pos > directive.End() {
214+
// Which variadic argument cursor sits inside.
215+
for i := firstVariadic; i < len(call.Args); i++ {
216+
if gastutil.NodeContains(call.Args[i], pos) {
217+
// Offset relative to parsedDirectives.
218+
// (Printf(a0,"%d %d %s",5, 6), firstVariadic=2,i=3)
219+
// ^ cursor here
220+
// -> ([5, 6, nil], firstVariadic=1)
221+
// ^
222+
firstVariadic = i - firstVariadic
223+
break
224+
}
225+
}
226+
index := -1
227+
for _, part := range parsedDirectives {
228+
switch part := part.(type) {
229+
case formatPercent:
230+
percent = part
231+
index++
232+
case formatVerb:
233+
if token.Pos(percent).IsValid() {
234+
if index == firstVariadic {
235+
// Placeholders behave like writting values from arguments to themselves,
236+
// so highlight them with Write semantic.
237+
highlightRange(result, token.Pos(percent), part.rang.end, protocol.Write)
238+
highlightRange(result, part.operand.Pos(), part.operand.End(), protocol.Read)
239+
return
240+
}
241+
percent = formatPercent(token.NoPos)
242+
}
243+
}
244+
}
245+
} else {
246+
// Cursor in format string.
247+
for _, part := range parsedDirectives {
248+
switch part := part.(type) {
249+
case formatPercent:
250+
percent = part
251+
case formatVerb:
252+
if token.Pos(percent).IsValid() {
253+
if token.Pos(percent) <= pos && pos <= part.rang.end {
254+
highlightRange(result, token.Pos(percent), part.rang.end, protocol.Write)
255+
if part.operand != nil {
256+
highlightRange(result, part.operand.Pos(), part.operand.End(), protocol.Read)
257+
}
258+
return
259+
}
260+
percent = formatPercent(token.NoPos)
261+
}
262+
}
263+
}
264+
}
265+
}
266+
267+
// Below are formatting directives definitions.
268+
type formatPercent token.Pos
269+
type formatLiteral struct {
270+
literal string
271+
rang posRange
272+
}
273+
type formatFlags struct {
274+
flag string
275+
rang posRange
276+
}
277+
type formatWidth struct {
278+
width int
279+
rang posRange
280+
}
281+
type formatPrec struct {
282+
prec int
283+
rang posRange
284+
}
285+
type formatVerb struct {
286+
verb rune
287+
rang posRange
288+
operand ast.Expr // verb's corresponding operand, may be nil
289+
}
290+
291+
type formatFunc func(fmt.State, rune)
292+
293+
var _ fmt.Formatter = formatFunc(nil)
294+
295+
func (f formatFunc) Format(st fmt.State, verb rune) { f(st, verb) }
296+
297+
// parsef parses a printf-style format string into its constituent components together with
298+
// their position in the source code, including [formatLiteral], formatting directives
299+
// [formatFlags], [formatPrecision], [formatWidth], [formatPrecision], [formatVerb], and its operand.
300+
func parsef(format string, pos token.Pos, args ...ast.Expr) []any {
301+
const sep = "!!!GOPLS_SEP!!!"
302+
// A Conversion represents a single % operation and its operand.
303+
type conversion struct {
304+
verb rune
305+
width int // or -1
306+
prec int // or -1
307+
flag string // some of "-+# 0"
308+
operand ast.Expr
309+
}
310+
var convs []conversion
311+
wrappers := make([]any, len(args))
312+
for i, operand := range args {
313+
wrappers[i] = formatFunc(func(st fmt.State, verb rune) {
314+
io.WriteString(st, sep)
315+
width, ok := st.Width()
316+
if !ok {
317+
width = -1
318+
}
319+
prec, ok := st.Precision()
320+
if !ok {
321+
prec = -1
322+
}
323+
flag := ""
324+
for _, b := range "-+# 0" {
325+
if st.Flag(int(b)) {
326+
flag += string(b)
327+
}
328+
}
329+
convs = append(convs, conversion{
330+
verb: verb,
331+
width: width,
332+
prec: prec,
333+
flag: flag,
334+
operand: operand,
335+
})
336+
})
337+
}
338+
339+
// Interleave the literals and the conversions.
340+
var directives []any
341+
for i, word := range strings.Split(fmt.Sprintf(format, wrappers...), sep) {
342+
if word != "" {
343+
directives = append(directives, formatLiteral{
344+
literal: word,
345+
rang: posRange{
346+
start: pos,
347+
end: pos + token.Pos(len(word)),
348+
},
349+
})
350+
pos = pos + token.Pos(len(word))
351+
}
352+
if i < len(convs) {
353+
conv := convs[i]
354+
// Collect %.
355+
directives = append(directives, formatPercent(pos))
356+
pos += 1
357+
// Collect flags.
358+
if flag := conv.flag; flag != "" {
359+
length := token.Pos(len(conv.flag))
360+
directives = append(directives, formatFlags{
361+
flag: flag,
362+
rang: posRange{
363+
start: pos,
364+
end: pos + length,
365+
},
366+
})
367+
pos += length
368+
}
369+
// Collect width.
370+
if width := conv.width; conv.width != -1 {
371+
length := token.Pos(len(fmt.Sprintf("%d", conv.width)))
372+
directives = append(directives, formatWidth{
373+
width: width,
374+
rang: posRange{
375+
start: pos,
376+
end: pos + length,
377+
},
378+
})
379+
pos += length
380+
}
381+
// Collect precision, which starts with a dot.
382+
if prec := conv.prec; conv.prec != -1 {
383+
length := token.Pos(len(fmt.Sprintf("%d", conv.prec))) + 1
384+
directives = append(directives, formatPrec{
385+
prec: prec,
386+
rang: posRange{
387+
start: pos,
388+
end: pos + length,
389+
},
390+
})
391+
pos += length
392+
}
393+
// Collect verb, which must be present.
394+
length := token.Pos(len(string(conv.verb)))
395+
directives = append(directives, formatVerb{
396+
verb: conv.verb,
397+
rang: posRange{
398+
start: pos,
399+
end: pos + length,
400+
},
401+
operand: conv.operand,
402+
})
403+
pos += length
404+
}
405+
}
406+
return directives
407+
}
408+
134409
type posRange struct {
135410
start, end token.Pos
136411
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
This test checks functionality of the printf-like directives and operands highlight.
2+
3+
-- flags --
4+
-ignore_extra_diags
5+
6+
-- highlights.go --
7+
package highlightprintf
8+
9+
import (
10+
"fmt"
11+
)
12+
13+
func BasicPrintfHighlights() {
14+
fmt.Printf("Hello %s, you have %d new messages!", "Alice", 5) //@hiloc(normals, "%s", write),hiloc(normalarg0, "\"Alice\"", read),highlightall(normals, normalarg0)
15+
fmt.Printf("Hello %s, you have %d new messages!", "Alice", 5) //@hiloc(normald, "%d", write),hiloc(normalargs1, "5", read),highlightall(normald, normalargs1)
16+
}
17+
18+
func ComplexPrintfHighlights() {
19+
fmt.Printf("Hello %#3.4s, you have %-2.3d new messages!", "Alice", 5) //@hiloc(complexs, "%#3.4s", write),hiloc(complexarg0, "\"Alice\"", read),highlightall(complexs, complexarg0)
20+
fmt.Printf("Hello %#3.4s, you have %-2.3d new messages!", "Alice", 5) //@hiloc(complexd, "%-2.3d", write),hiloc(complexarg1, "5", read),highlightall(complexd, complexarg1)
21+
}
22+
23+
func MissingDirectives() {
24+
fmt.Printf("Hello %s, you have 5 new messages!", "Alice", 5) //@hiloc(missings, "%s", write),hiloc(missingargs0, "\"Alice\"", read),highlightall(missings, missingargs0)
25+
}
26+
27+
func TooManyDirectives() {
28+
fmt.Printf("Hello %s, you have %d new %s %q messages!", "Alice", 5) //@hiloc(toomanys, "%s", write),hiloc(toomanyargs0, "\"Alice\"", read),highlightall(toomanys, toomanyargs0)
29+
fmt.Printf("Hello %s, you have %d new %s %q messages!", "Alice", 5) //@hiloc(toomanyd, "%d", write),hiloc(toomanyargs1, "5", read),highlightall(toomanyd, toomanyargs1)
30+
}
31+
32+
func SpecialChars() {
33+
fmt.Printf("Hello \n %s, you \t \n have %d new messages!", "Alice", 5) //@hiloc(specials, "%s", write),hiloc(specialargs0, "\"Alice\"", read),highlightall(specials, specialargs0)
34+
fmt.Printf("Hello \n %s, you \t \n have %d new messages!", "Alice", 5) //@hiloc(speciald, "%d", write),hiloc(specialargs1, "5", read),highlightall(speciald, specialargs1)
35+
}

0 commit comments

Comments
 (0)