Skip to content

Commit 458067f

Browse files
h9jianggopherbot
authored andcommitted
gopls/internal/golang: improve test package name selection for new file
This commit refines the logic for determining the package name of new test files in gopls. Previously, when creating a new test file, gopls blindly choose x_test package. This commit expands the logic to consider the target function's signature. For golang/vscode-go#1594 Change-Id: Ia78003bf007479e48861ce643c8c7c366ff960a3 Reviewed-on: https://go-review.googlesource.com/c/tools/+/629978 Auto-Submit: Hongxiang Jiang <[email protected]> Reviewed-by: Robert Findley <[email protected]> Reviewed-by: Alan Donovan <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent 68caf84 commit 458067f

File tree

2 files changed

+192
-31
lines changed

2 files changed

+192
-31
lines changed

gopls/internal/golang/addtest.go

+70-25
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
9797
{{- $last := last .Receiver.Constructor.Results}}
9898
{{- if eq $last.Type "error"}}
9999
if err != nil {
100-
t.Fatalf("could not contruct receiver type: %v", err)
100+
t.Fatalf("could not construct receiver type: %v", err)
101101
}
102102
{{- end}}
103103
{{- else}}
@@ -309,13 +309,30 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
309309
// edits contains all the text edits to be applied to the test file.
310310
edits []protocol.TextEdit
311311
// xtest indicates whether the test file use package x or x_test.
312-
// TODO(hxjiang): For now, we try to interpret the user's intention by
313-
// reading the foo_test.go's package name. Instead, we can discuss the option
314-
// to interpret the user's intention by which function they are selecting.
315-
// Have one file for x_test package testing, one file for x package testing.
312+
// TODO(hxjiang): We can discuss the option to interpret the user's
313+
// intention by which function they are selecting. Have one file for
314+
// x_test package testing, one file for x package testing.
316315
xtest = true
317316
)
318317

318+
start, end, err := pgf.RangePos(loc.Range)
319+
if err != nil {
320+
return nil, err
321+
}
322+
323+
path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)
324+
if len(path) < 2 {
325+
return nil, fmt.Errorf("no enclosing function")
326+
}
327+
328+
decl, ok := path[len(path)-2].(*ast.FuncDecl)
329+
if !ok {
330+
return nil, fmt.Errorf("no enclosing function")
331+
}
332+
333+
fn := pkg.TypesInfo().Defs[decl.Name].(*types.Func)
334+
sig := fn.Signature()
335+
319336
testPGF, err := snapshot.ParseGo(ctx, testFH, parsego.Header)
320337
if err != nil {
321338
if !errors.Is(err, os.ErrNotExist) {
@@ -352,7 +369,53 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
352369
header.WriteString("\n\n")
353370
}
354371

355-
fmt.Fprintf(&header, "package %s_test\n", pkg.Types().Name())
372+
// Determine if a new test file should use in-package test (package x)
373+
// or external test (package x_test). If any of the function parameters
374+
// reference an unexported object, we cannot write out test cases from
375+
// an x_test package.
376+
externalTestOK := func() bool {
377+
if !fn.Exported() {
378+
return false
379+
}
380+
if fn.Signature().Recv() != nil {
381+
if _, ident, _ := goplsastutil.UnpackRecv(decl.Recv.List[0].Type); ident == nil || !ident.IsExported() {
382+
return false
383+
}
384+
}
385+
refsUnexported := false
386+
ast.Inspect(decl, func(n ast.Node) bool {
387+
// The original function refs to an unexported object from the
388+
// same package, so further inspection is unnecessary.
389+
if refsUnexported {
390+
return false
391+
}
392+
switch t := n.(type) {
393+
case *ast.BlockStmt:
394+
// Avoid inspect the function body.
395+
return false
396+
case *ast.Ident:
397+
// Use test variant (package foo) if the function signature
398+
// references any unexported objects (like types or
399+
// constants) from the same package.
400+
// Note: types.PkgName is excluded from this check as it's
401+
// always defined in the same package.
402+
if obj, ok := pkg.TypesInfo().Uses[t]; ok && !obj.Exported() && obj.Pkg() == pkg.Types() && !is[*types.PkgName](obj) {
403+
refsUnexported = true
404+
}
405+
return false
406+
default:
407+
return true
408+
}
409+
})
410+
return !refsUnexported
411+
}
412+
413+
xtest = externalTestOK()
414+
if xtest {
415+
fmt.Fprintf(&header, "package %s_test\n", pkg.Types().Name())
416+
} else {
417+
fmt.Fprintf(&header, "package %s\n", pkg.Types().Name())
418+
}
356419

357420
// Write the copyright and package decl to the beginning of the file.
358421
edits = append(edits, protocol.TextEdit{
@@ -412,24 +475,6 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
412475
return p.Name()
413476
}
414477

415-
start, end, err := pgf.RangePos(loc.Range)
416-
if err != nil {
417-
return nil, err
418-
}
419-
420-
path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)
421-
if len(path) < 2 {
422-
return nil, fmt.Errorf("no enclosing function")
423-
}
424-
425-
decl, ok := path[len(path)-2].(*ast.FuncDecl)
426-
if !ok {
427-
return nil, fmt.Errorf("no enclosing function")
428-
}
429-
430-
fn := pkg.TypesInfo().Defs[decl.Name].(*types.Func)
431-
sig := fn.Signature()
432-
433478
if xtest {
434479
// Reject if function/method is unexported.
435480
if !fn.Exported() {
@@ -438,7 +483,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
438483

439484
// Reject if receiver is unexported.
440485
if sig.Recv() != nil {
441-
if _, ident, _ := goplsastutil.UnpackRecv(decl.Recv.List[0].Type); !ident.IsExported() {
486+
if _, ident, _ := goplsastutil.UnpackRecv(decl.Recv.List[0].Type); ident == nil || !ident.IsExported() {
442487
return nil, fmt.Errorf("cannot add external test for method %s.%s as receiver type is not exported", ident.Name, decl.Name)
443488
}
444489
}

gopls/internal/test/marker/testdata/codeaction/addtest.txt

+122-6
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,22 @@ func Foo(in string) string {return in} //@codeaction("Foo", "source.addTest", ed
100100
-- missingtestfile/missingtestfile.go --
101101
package main
102102

103+
type Bar struct {}
104+
105+
type foo struct {}
106+
103107
func ExportedFunction(in string) string {return in} //@codeaction("ExportedFunction", "source.addTest", edit=missing_test_file_exported_function)
104108

105-
type Bar struct {}
109+
func UnexportedInputParam(in string, f foo) string {return in} //@codeaction("UnexportedInputParam", "source.addTest", edit=missing_test_file_function_unexported_input)
110+
111+
func unexportedFunction(in string) string {return in} //@codeaction("unexportedFunction", "source.addTest", edit=missing_test_file_unexported_function)
106112

107113
func (*Bar) ExportedMethod(in string) string {return in} //@codeaction("ExportedMethod", "source.addTest", edit=missing_test_file_exported_recv_exported_method)
108114

115+
func (*Bar) UnexportedInputParam(in string, f foo) string {return in} //@codeaction("UnexportedInputParam", "source.addTest", edit=missing_test_file_method_unexported_input)
116+
117+
func (*foo) ExportedMethod(in string) string {return in} //@codeaction("ExportedMethod", "source.addTest", edit=missing_test_file_unexported_recv)
118+
109119
-- @missing_test_file_exported_function/missingtestfile/missingtestfile_test.go --
110120
@@ -0,0 +1,26 @@
111121
+package main_test
@@ -164,6 +174,112 @@ func (*Bar) ExportedMethod(in string) string {return in} //@codeaction("Exported
164174
+ })
165175
+ }
166176
+}
177+
-- @missing_test_file_function_unexported_input/missingtestfile/missingtestfile_test.go --
178+
@@ -0,0 +1,24 @@
179+
+package main
180+
+
181+
+import "testing"
182+
+
183+
+func TestUnexportedInputParam(t *testing.T) {
184+
+ tests := []struct {
185+
+ name string // description of this test case
186+
+ // Named input parameters for target function.
187+
+ in string
188+
+ f foo
189+
+ want string
190+
+ }{
191+
+ // TODO: Add test cases.
192+
+ }
193+
+ for _, tt := range tests {
194+
+ t.Run(tt.name, func(t *testing.T) {
195+
+ got := UnexportedInputParam(tt.in, tt.f)
196+
+ // TODO: update the condition below to compare got with tt.want.
197+
+ if true {
198+
+ t.Errorf("UnexportedInputParam() = %v, want %v", got, tt.want)
199+
+ }
200+
+ })
201+
+ }
202+
+}
203+
-- @missing_test_file_method_unexported_input/missingtestfile/missingtestfile_test.go --
204+
@@ -0,0 +1,26 @@
205+
+package main
206+
+
207+
+import "testing"
208+
+
209+
+func TestBar_UnexportedInputParam(t *testing.T) {
210+
+ tests := []struct {
211+
+ name string // description of this test case
212+
+ // Named input parameters for target function.
213+
+ in string
214+
+ f foo
215+
+ want string
216+
+ }{
217+
+ // TODO: Add test cases.
218+
+ }
219+
+ for _, tt := range tests {
220+
+ t.Run(tt.name, func(t *testing.T) {
221+
+ // TODO: construct the receiver type.
222+
+ var b Bar
223+
+ got := b.UnexportedInputParam(tt.in, tt.f)
224+
+ // TODO: update the condition below to compare got with tt.want.
225+
+ if true {
226+
+ t.Errorf("UnexportedInputParam() = %v, want %v", got, tt.want)
227+
+ }
228+
+ })
229+
+ }
230+
+}
231+
-- @missing_test_file_unexported_function/missingtestfile/missingtestfile_test.go --
232+
@@ -0,0 +1,23 @@
233+
+package main
234+
+
235+
+import "testing"
236+
+
237+
+func Test_unexportedFunction(t *testing.T) {
238+
+ tests := []struct {
239+
+ name string // description of this test case
240+
+ // Named input parameters for target function.
241+
+ in string
242+
+ want string
243+
+ }{
244+
+ // TODO: Add test cases.
245+
+ }
246+
+ for _, tt := range tests {
247+
+ t.Run(tt.name, func(t *testing.T) {
248+
+ got := unexportedFunction(tt.in)
249+
+ // TODO: update the condition below to compare got with tt.want.
250+
+ if true {
251+
+ t.Errorf("unexportedFunction() = %v, want %v", got, tt.want)
252+
+ }
253+
+ })
254+
+ }
255+
+}
256+
-- @missing_test_file_unexported_recv/missingtestfile/missingtestfile_test.go --
257+
@@ -0,0 +1,25 @@
258+
+package main
259+
+
260+
+import "testing"
261+
+
262+
+func Test_foo_ExportedMethod(t *testing.T) {
263+
+ tests := []struct {
264+
+ name string // description of this test case
265+
+ // Named input parameters for target function.
266+
+ in string
267+
+ want string
268+
+ }{
269+
+ // TODO: Add test cases.
270+
+ }
271+
+ for _, tt := range tests {
272+
+ t.Run(tt.name, func(t *testing.T) {
273+
+ // TODO: construct the receiver type.
274+
+ var f foo
275+
+ got := f.ExportedMethod(tt.in)
276+
+ // TODO: update the condition below to compare got with tt.want.
277+
+ if true {
278+
+ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want)
279+
+ }
280+
+ })
281+
+ }
282+
+}
167283
-- xpackagetestfile/xpackagetestfile.go --
168284
package main
169285

@@ -519,7 +635,7 @@ package main
519635
+ t.Run(tt.name, func(t *testing.T) {
520636
+ q, err := newQux1()
521637
+ if err != nil {
522-
+ t.Fatalf("could not contruct receiver type: %v", err)
638+
+ t.Fatalf("could not construct receiver type: %v", err)
523639
+ }
524640
+ got := q.method(tt.in)
525641
+ // TODO: update the condition below to compare got with tt.want.
@@ -877,7 +993,7 @@ func (*ReturnPtrError) Method(in string) string {return in} //@codeaction("Metho
877993
+ t.Run(tt.name, func(t *testing.T) {
878994
+ r, err := main.NewReturnTypeError()
879995
+ if err != nil {
880-
+ t.Fatalf("could not contruct receiver type: %v", err)
996+
+ t.Fatalf("could not construct receiver type: %v", err)
881997
+ }
882998
+ got := r.Method(tt.in)
883999
+ // TODO: update the condition below to compare got with tt.want.
@@ -938,7 +1054,7 @@ func (*ReturnPtrError) Method(in string) string {return in} //@codeaction("Metho
9381054
+ t.Run(tt.name, func(t *testing.T) {
9391055
+ r, err := main.NewReturnPtrError()
9401056
+ if err != nil {
941-
+ t.Fatalf("could not contruct receiver type: %v", err)
1057+
+ t.Fatalf("could not construct receiver type: %v", err)
9421058
+ }
9431059
+ got := r.Method(tt.in)
9441060
+ // TODO: update the condition below to compare got with tt.want.
@@ -1018,7 +1134,7 @@ func (*Bar) Method(in string) string {return in} //@codeaction("Method", "source
10181134
+ t.Run(tt.name, func(t *testing.T) {
10191135
+ b, err := main.ABar()
10201136
+ if err != nil {
1021-
+ t.Fatalf("could not contruct receiver type: %v", err)
1137+
+ t.Fatalf("could not construct receiver type: %v", err)
10221138
+ }
10231139
+ got := b.Method(tt.in)
10241140
+ // TODO: update the condition below to compare got with tt.want.
@@ -1403,7 +1519,7 @@ var local renamedctx.Context
14031519
+ t.Run(tt.name, func(t *testing.T) {
14041520
+ f, err := main.NewFoo(renamedctx.Background())
14051521
+ if err != nil {
1406-
+ t.Fatalf("could not contruct receiver type: %v", err)
1522+
+ t.Fatalf("could not construct receiver type: %v", err)
14071523
+ }
14081524
+ got, got2, got3 := f.Method(renamedctx.Background(), "", "")
14091525
+ // TODO: update the condition below to compare got with tt.want.

0 commit comments

Comments
 (0)