Skip to content
This repository was archived by the owner on Jun 27, 2023. It is now read-only.

Commit 7233f2a

Browse files
committed
Modify mockgen for embed option
1 parent 5b45562 commit 7233f2a

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

mockgen/mockgen.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ var (
6161
selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.")
6262
writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.")
6363
copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header")
64+
embed = flag.Bool("embed", false, "Embed source interface into generated mock structure")
6465

6566
debugParser = flag.Bool("debug_parser", false, "Print out parser results only.")
6667
showVersion = flag.Bool("version", false, "Print version.")
@@ -341,6 +342,9 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
341342
g.packageMap[pth] = pkgName
342343
localNames[pkgName] = true
343344
}
345+
if *embed {
346+
g.packageMap[g.srcPackage] = pkg.Name
347+
}
344348

345349
if *writePkgComment {
346350
g.p("// Package %v is a generated GoMock package.", outputPkgName)
@@ -362,7 +366,7 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
362366
g.p(")")
363367

364368
for _, intf := range pkg.Interfaces {
365-
if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
369+
if err := g.GenerateMockInterface(pkg.Name, intf, outputPackagePath); err != nil {
366370
return err
367371
}
368372
}
@@ -404,14 +408,17 @@ func (g *generator) formattedTypeParams(it *model.Interface, pkgOverride string)
404408
return long.String(), short.String()
405409
}
406410

407-
func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error {
411+
func (g *generator) GenerateMockInterface(pkgName string, intf *model.Interface, outputPackagePath string) error {
408412
mockType := g.mockName(intf.Name)
409413
longTp, shortTp := g.formattedTypeParams(intf, outputPackagePath)
410414

411415
g.p("")
412416
g.p("// %v is a mock of %v interface.", mockType, intf.Name)
413417
g.p("type %v%v struct {", mockType, longTp)
414418
g.in()
419+
if *embed {
420+
g.p("%v.%v", pkgName, intf.Name)
421+
}
415422
g.p("ctrl *gomock.Controller")
416423
g.p("recorder *%vMockRecorder%v", mockType, shortTp)
417424
g.out()

0 commit comments

Comments
 (0)