@@ -37,22 +37,22 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
37
37
// TODO: stricter rules for selectorExpr.
38
38
case * ast.BasicLit , * ast.CompositeLit , * ast.IndexExpr , * ast.SliceExpr ,
39
39
* ast.UnaryExpr , * ast.BinaryExpr , * ast.SelectorExpr :
40
- lhsName , _ := generateAvailableIdentifier (expr .Pos (), path , pkg , info , "x" , 0 )
40
+ lhsName , _ := generateAvailableName (expr .Pos (), path , pkg , info , "x" , 0 )
41
41
lhsNames = append (lhsNames , lhsName )
42
42
case * ast.CallExpr :
43
43
tup , ok := info .TypeOf (expr ).(* types.Tuple )
44
44
if ! ok {
45
45
// If the call expression only has one return value, we can treat it the
46
46
// same as our standard extract variable case.
47
- lhsName , _ := generateAvailableIdentifier (expr .Pos (), path , pkg , info , "x" , 0 )
47
+ lhsName , _ := generateAvailableName (expr .Pos (), path , pkg , info , "x" , 0 )
48
48
lhsNames = append (lhsNames , lhsName )
49
49
break
50
50
}
51
51
idx := 0
52
52
for i := 0 ; i < tup .Len (); i ++ {
53
53
// Generate a unique variable for each return value.
54
54
var lhsName string
55
- lhsName , idx = generateAvailableIdentifier (expr .Pos (), path , pkg , info , "x" , idx )
55
+ lhsName , idx = generateAvailableName (expr .Pos (), path , pkg , info , "x" , idx )
56
56
lhsNames = append (lhsNames , lhsName )
57
57
}
58
58
default :
@@ -150,12 +150,12 @@ func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.
150
150
return string (content [lineOffset :stmtOffset ]), nil
151
151
}
152
152
153
- // generateAvailableIdentifier adjusts the new function name until there are no collisions in scope.
153
+ // generateAvailableName adjusts the new function name until there are no collisions in scope.
154
154
// Possible collisions include other function and variable names. Returns the next index to check for prefix.
155
- func generateAvailableIdentifier (pos token.Pos , path []ast.Node , pkg * types.Package , info * types.Info , prefix string , idx int ) (string , int ) {
155
+ func generateAvailableName (pos token.Pos , path []ast.Node , pkg * types.Package , info * types.Info , prefix string , idx int ) (string , int ) {
156
156
scopes := CollectScopes (info , path , pos )
157
157
scopes = append (scopes , pkg .Scope ())
158
- return generateIdentifier (idx , prefix , func (name string ) bool {
158
+ return generateName (idx , prefix , func (name string ) bool {
159
159
for _ , scope := range scopes {
160
160
if scope != nil && scope .Lookup (name ) != nil {
161
161
return true
@@ -165,7 +165,31 @@ func generateAvailableIdentifier(pos token.Pos, path []ast.Node, pkg *types.Pack
165
165
})
166
166
}
167
167
168
- func generateIdentifier (idx int , prefix string , hasCollision func (string ) bool ) (string , int ) {
168
+ // generateNameOutsideOfRange is like generateAvailableName, but ignores names
169
+ // declared between start and end for the purposes of detecting conflicts.
170
+ //
171
+ // This is used for function extraction, where [start, end) will be extracted
172
+ // to a new scope.
173
+ func generateNameOutsideOfRange (start , end token.Pos , path []ast.Node , pkg * types.Package , info * types.Info , prefix string , idx int ) (string , int ) {
174
+ scopes := CollectScopes (info , path , start )
175
+ scopes = append (scopes , pkg .Scope ())
176
+ return generateName (idx , prefix , func (name string ) bool {
177
+ for _ , scope := range scopes {
178
+ if scope != nil {
179
+ if obj := scope .Lookup (name ); obj != nil {
180
+ // Only report a collision if the object declaration was outside the
181
+ // extracted range.
182
+ if obj .Pos () < start || end <= obj .Pos () {
183
+ return true
184
+ }
185
+ }
186
+ }
187
+ }
188
+ return false
189
+ })
190
+ }
191
+
192
+ func generateName (idx int , prefix string , hasCollision func (string ) bool ) (string , int ) {
169
193
name := prefix
170
194
if idx != 0 {
171
195
name += fmt .Sprintf ("%d" , idx )
@@ -182,7 +206,7 @@ func generateIdentifier(idx int, prefix string, hasCollision func(string) bool)
182
206
type returnVariable struct {
183
207
// name is the identifier that is used on the left-hand side of the call to
184
208
// the extracted function.
185
- name ast.Expr
209
+ name * ast.Ident
186
210
// decl is the declaration of the variable. It is used in the type signature of the
187
211
// extracted function and for variable declarations.
188
212
decl * ast.Field
@@ -517,7 +541,7 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte
517
541
// statements in the selection. Update the type signature of the extracted
518
542
// function and construct the if statement that will be inserted in the enclosing
519
543
// function.
520
- retVars , ifReturn , err = generateReturnInfo (enclosing , pkg , path , file , info , start , hasNonNestedReturn )
544
+ retVars , ifReturn , err = generateReturnInfo (enclosing , pkg , path , file , info , start , end , hasNonNestedReturn )
521
545
if err != nil {
522
546
return nil , nil , err
523
547
}
@@ -552,7 +576,7 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte
552
576
funName = name
553
577
} else {
554
578
name = "newFunction"
555
- funName , _ = generateAvailableIdentifier (start , path , pkg , info , name , 0 )
579
+ funName , _ = generateAvailableName (start , path , pkg , info , name , 0 )
556
580
}
557
581
extractedFunCall := generateFuncCall (hasNonNestedReturn , hasReturnValues , params ,
558
582
append (returns , getNames (retVars )... ), funName , sym , receiverName )
@@ -1187,12 +1211,12 @@ func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) {
1187
1211
// signature of the extracted function. We prepare names, signatures, and "zero values" that
1188
1212
// represent the new variables. We also use this information to construct the if statement that
1189
1213
// is inserted below the call to the extracted function.
1190
- func generateReturnInfo (enclosing * ast.FuncType , pkg * types.Package , path []ast.Node , file * ast.File , info * types.Info , pos token.Pos , hasNonNestedReturns bool ) ([]* returnVariable , * ast.IfStmt , error ) {
1214
+ func generateReturnInfo (enclosing * ast.FuncType , pkg * types.Package , path []ast.Node , file * ast.File , info * types.Info , start , end token.Pos , hasNonNestedReturns bool ) ([]* returnVariable , * ast.IfStmt , error ) {
1191
1215
var retVars []* returnVariable
1192
1216
var cond * ast.Ident
1193
1217
if ! hasNonNestedReturns {
1194
1218
// Generate information for the added bool value.
1195
- name , _ := generateAvailableIdentifier ( pos , path , pkg , info , "shouldReturn" , 0 )
1219
+ name , _ := generateNameOutsideOfRange ( start , end , path , pkg , info , "shouldReturn" , 0 )
1196
1220
cond = & ast.Ident {Name : name }
1197
1221
retVars = append (retVars , & returnVariable {
1198
1222
name : cond ,
@@ -1202,7 +1226,7 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.
1202
1226
}
1203
1227
// Generate information for the values in the return signature of the enclosing function.
1204
1228
if enclosing .Results != nil {
1205
- idx := 0
1229
+ nameIdx := make ( map [ string ] int ) // last integral suffixes of generated names
1206
1230
for _ , field := range enclosing .Results .List {
1207
1231
typ := info .TypeOf (field .Type )
1208
1232
if typ == nil {
@@ -1213,17 +1237,32 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.
1213
1237
if expr == nil {
1214
1238
return nil , nil , fmt .Errorf ("nil AST expression" )
1215
1239
}
1216
- var name string
1217
- name , idx = generateAvailableIdentifier (pos , path , pkg , info , "returnValue" , idx )
1218
- z := analysisinternal .ZeroValue (file , pkg , typ )
1219
- if z == nil {
1220
- return nil , nil , fmt .Errorf ("can't generate zero value for %T" , typ )
1240
+ names := []string {"" }
1241
+ if len (field .Names ) > 0 {
1242
+ names = nil
1243
+ for _ , n := range field .Names {
1244
+ names = append (names , n .Name )
1245
+ }
1246
+ }
1247
+ for _ , name := range names {
1248
+ bestName := "result"
1249
+ if name != "" && name != "_" {
1250
+ bestName = name
1251
+ } else if n , ok := varNameForType (typ ); ok {
1252
+ bestName = n
1253
+ }
1254
+ retName , idx := generateNameOutsideOfRange (start , end , path , pkg , info , bestName , nameIdx [bestName ])
1255
+ nameIdx [bestName ] = idx
1256
+ z := analysisinternal .ZeroValue (file , pkg , typ )
1257
+ if z == nil {
1258
+ return nil , nil , fmt .Errorf ("can't generate zero value for %T" , typ )
1259
+ }
1260
+ retVars = append (retVars , & returnVariable {
1261
+ name : ast .NewIdent (retName ),
1262
+ decl : & ast.Field {Type : expr },
1263
+ zeroVal : z ,
1264
+ })
1221
1265
}
1222
- retVars = append (retVars , & returnVariable {
1223
- name : ast .NewIdent (name ),
1224
- decl : & ast.Field {Type : expr },
1225
- zeroVal : z ,
1226
- })
1227
1266
}
1228
1267
}
1229
1268
var ifReturn * ast.IfStmt
@@ -1240,6 +1279,48 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.
1240
1279
return retVars , ifReturn , nil
1241
1280
}
1242
1281
1282
+ type objKey struct { pkg , name string }
1283
+
1284
+ // conventionalVarNames specifies conventional names for variables with various
1285
+ // standard library types.
1286
+ //
1287
+ // Keep this up to date with completion.conventionalAcronyms.
1288
+ //
1289
+ // TODO(rfindley): consider factoring out a "conventions" library.
1290
+ var conventionalVarNames = map [objKey ]string {
1291
+ {"" , "error" }: "err" ,
1292
+ {"context" , "Context" }: "ctx" ,
1293
+ {"sql" , "Tx" }: "tx" ,
1294
+ {"http" , "ResponseWriter" }: "rw" , // Note: same as [AbbreviateVarName].
1295
+ }
1296
+
1297
+ // varNameForTypeName chooses a "good" name for a variable with the given type,
1298
+ // if possible. Otherwise, it returns "", false.
1299
+ //
1300
+ // For special types, it uses known conventional names.
1301
+ func varNameForType (t types.Type ) (string , bool ) {
1302
+ var typeName string
1303
+ if tn , ok := t .(interface { Obj () * types.TypeName }); ok {
1304
+ obj := tn .Obj ()
1305
+ k := objKey {name : obj .Name ()}
1306
+ if obj .Pkg () != nil {
1307
+ k .pkg = obj .Pkg ().Name ()
1308
+ }
1309
+ if name , ok := conventionalVarNames [k ]; ok {
1310
+ return name , true
1311
+ }
1312
+ typeName = obj .Name ()
1313
+ } else if b , ok := t .(* types.Basic ); ok {
1314
+ typeName = b .Name ()
1315
+ }
1316
+
1317
+ if typeName == "" {
1318
+ return "" , false
1319
+ }
1320
+
1321
+ return AbbreviateVarName (typeName ), true
1322
+ }
1323
+
1243
1324
// adjustReturnStatements adds "zero values" of the given types to each return statement
1244
1325
// in the given AST node.
1245
1326
func adjustReturnStatements (returnTypes []* ast.Field , seenVars map [types.Object ]ast.Expr , file * ast.File , pkg * types.Package , extractedBlock * ast.BlockStmt ) error {
@@ -1346,9 +1427,8 @@ func initializeVars(uninitialized []types.Object, retVars []*returnVariable, see
1346
1427
// Each variable added from a return statement in the selection
1347
1428
// must be initialized.
1348
1429
for i , retVar := range retVars {
1349
- n := retVar .name .(* ast.Ident )
1350
1430
valSpec := & ast.ValueSpec {
1351
- Names : []* ast.Ident {n },
1431
+ Names : []* ast.Ident {retVar . name },
1352
1432
Type : retVars [i ].decl .Type ,
1353
1433
}
1354
1434
genDecl := & ast.GenDecl {
0 commit comments