@@ -3,13 +3,16 @@ use clippy_utils::diagnostics::{span_lint_and_sugg, span_lint_hir_and_then};
3
3
use clippy_utils:: source:: { snippet, snippet_with_applicability} ;
4
4
use clippy_utils:: sugg:: Sugg ;
5
5
use clippy_utils:: ty:: is_type_diagnostic_item;
6
- use clippy_utils:: { is_trait_method, path_to_local_id} ;
6
+ use clippy_utils:: { can_move_expr_to_closure , is_trait_method, path_to_local , path_to_local_id, CaptureKind } ;
7
7
use if_chain:: if_chain;
8
+ use rustc_data_structures:: fx:: FxHashMap ;
8
9
use rustc_errors:: Applicability ;
9
10
use rustc_hir:: intravisit:: { walk_block, walk_expr, NestedVisitorMap , Visitor } ;
10
- use rustc_hir:: { Block , Expr , ExprKind , HirId , PatKind , StmtKind } ;
11
+ use rustc_hir:: { Block , Expr , ExprKind , HirId , HirIdSet , Local , Mutability , Node , PatKind , Stmt , StmtKind } ;
11
12
use rustc_lint:: LateContext ;
12
13
use rustc_middle:: hir:: map:: Map ;
14
+ use rustc_middle:: ty:: subst:: GenericArgKind ;
15
+ use rustc_middle:: ty:: { self , TyS } ;
13
16
use rustc_span:: sym;
14
17
use rustc_span:: { MultiSpan , Span } ;
15
18
@@ -83,7 +86,8 @@ fn check_needless_collect_indirect_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateCo
83
86
is_type_diagnostic_item( cx, ty, sym:: VecDeque ) ||
84
87
is_type_diagnostic_item( cx, ty, sym:: BinaryHeap ) ||
85
88
is_type_diagnostic_item( cx, ty, sym:: LinkedList ) ;
86
- if let Some ( iter_calls) = detect_iter_and_into_iters( block, id) ;
89
+ let iter_ty = cx. typeck_results( ) . expr_ty( iter_source) ;
90
+ if let Some ( iter_calls) = detect_iter_and_into_iters( block, id, cx, get_captured_ids( cx, iter_ty) ) ;
87
91
if let [ iter_call] = & * iter_calls;
88
92
then {
89
93
let mut used_count_visitor = UsedCountVisitor {
@@ -167,37 +171,89 @@ enum IterFunctionKind {
167
171
Contains ( Span ) ,
168
172
}
169
173
170
- struct IterFunctionVisitor {
171
- uses : Vec < IterFunction > ,
174
+ struct IterFunctionVisitor < ' a , ' tcx > {
175
+ illegal_mutable_capture_ids : HirIdSet ,
176
+ current_mutably_captured_ids : HirIdSet ,
177
+ cx : & ' a LateContext < ' tcx > ,
178
+ uses : Vec < Option < IterFunction > > ,
179
+ hir_id_uses_map : FxHashMap < HirId , usize > ,
180
+ current_statement_hir_id : Option < HirId > ,
172
181
seen_other : bool ,
173
182
target : HirId ,
174
183
}
175
- impl < ' tcx > Visitor < ' tcx > for IterFunctionVisitor {
184
+ impl < ' tcx > Visitor < ' tcx > for IterFunctionVisitor < ' _ , ' tcx > {
185
+ fn visit_block ( & mut self , block : & ' tcx Block < ' tcx > ) {
186
+ for ( expr, hir_id) in block. stmts . iter ( ) . filter_map ( get_expr_and_hir_id_from_stmt) {
187
+ self . visit_block_expr ( expr, hir_id) ;
188
+ }
189
+ if let Some ( expr) = block. expr {
190
+ self . visit_block_expr ( expr, None ) ;
191
+ }
192
+ }
193
+
176
194
fn visit_expr ( & mut self , expr : & ' tcx Expr < ' tcx > ) {
177
195
// Check function calls on our collection
178
196
if let ExprKind :: MethodCall ( method_name, _, [ recv, args @ ..] , _) = & expr. kind {
197
+ if method_name. ident . name == sym ! ( collect) && is_trait_method ( self . cx , expr, sym:: Iterator ) {
198
+ self . current_mutably_captured_ids = get_captured_ids ( self . cx , self . cx . typeck_results ( ) . expr_ty ( recv) ) ;
199
+ self . visit_expr ( recv) ;
200
+ return ;
201
+ }
202
+
179
203
if path_to_local_id ( recv, self . target ) {
180
- match & * method_name. ident . name . as_str ( ) {
181
- "into_iter" => self . uses . push ( IterFunction {
182
- func : IterFunctionKind :: IntoIter ,
183
- span : expr. span ,
184
- } ) ,
185
- "len" => self . uses . push ( IterFunction {
186
- func : IterFunctionKind :: Len ,
187
- span : expr. span ,
188
- } ) ,
189
- "is_empty" => self . uses . push ( IterFunction {
190
- func : IterFunctionKind :: IsEmpty ,
191
- span : expr. span ,
192
- } ) ,
193
- "contains" => self . uses . push ( IterFunction {
194
- func : IterFunctionKind :: Contains ( args[ 0 ] . span ) ,
195
- span : expr. span ,
196
- } ) ,
197
- _ => self . seen_other = true ,
204
+ if self
205
+ . illegal_mutable_capture_ids
206
+ . intersection ( & self . current_mutably_captured_ids )
207
+ . next ( )
208
+ . is_none ( )
209
+ {
210
+ if let Some ( hir_id) = self . current_statement_hir_id {
211
+ self . hir_id_uses_map . insert ( hir_id, self . uses . len ( ) ) ;
212
+ }
213
+ match & * method_name. ident . name . as_str ( ) {
214
+ "into_iter" => self . uses . push ( Some ( IterFunction {
215
+ func : IterFunctionKind :: IntoIter ,
216
+ span : expr. span ,
217
+ } ) ) ,
218
+ "len" => self . uses . push ( Some ( IterFunction {
219
+ func : IterFunctionKind :: Len ,
220
+ span : expr. span ,
221
+ } ) ) ,
222
+ "is_empty" => self . uses . push ( Some ( IterFunction {
223
+ func : IterFunctionKind :: IsEmpty ,
224
+ span : expr. span ,
225
+ } ) ) ,
226
+ "contains" => self . uses . push ( Some ( IterFunction {
227
+ func : IterFunctionKind :: Contains ( args[ 0 ] . span ) ,
228
+ span : expr. span ,
229
+ } ) ) ,
230
+ _ => {
231
+ self . seen_other = true ;
232
+ if let Some ( hir_id) = self . current_statement_hir_id {
233
+ self . hir_id_uses_map . remove ( & hir_id) ;
234
+ }
235
+ } ,
236
+ }
198
237
}
199
238
return ;
200
239
}
240
+
241
+ if let Some ( hir_id) = path_to_local ( recv) {
242
+ if let Some ( index) = self . hir_id_uses_map . remove ( & hir_id) {
243
+ if self
244
+ . illegal_mutable_capture_ids
245
+ . intersection ( & self . current_mutably_captured_ids )
246
+ . next ( )
247
+ . is_none ( )
248
+ {
249
+ if let Some ( hir_id) = self . current_statement_hir_id {
250
+ self . hir_id_uses_map . insert ( hir_id, index) ;
251
+ }
252
+ } else {
253
+ self . uses [ index] = None ;
254
+ }
255
+ }
256
+ }
201
257
}
202
258
// Check if the collection is used for anything else
203
259
if path_to_local_id ( expr, self . target ) {
@@ -213,6 +269,28 @@ impl<'tcx> Visitor<'tcx> for IterFunctionVisitor {
213
269
}
214
270
}
215
271
272
+ impl < ' tcx > IterFunctionVisitor < ' _ , ' tcx > {
273
+ fn visit_block_expr ( & mut self , expr : & ' tcx Expr < ' tcx > , hir_id : Option < HirId > ) {
274
+ self . current_statement_hir_id = hir_id;
275
+ self . current_mutably_captured_ids = get_captured_ids ( self . cx , self . cx . typeck_results ( ) . expr_ty ( expr) ) ;
276
+ self . visit_expr ( expr) ;
277
+ }
278
+ }
279
+
280
+ fn get_expr_and_hir_id_from_stmt < ' v > ( stmt : & ' v Stmt < ' v > ) -> Option < ( & ' v Expr < ' v > , Option < HirId > ) > {
281
+ match stmt. kind {
282
+ StmtKind :: Expr ( expr) | StmtKind :: Semi ( expr) => Some ( ( expr, None ) ) ,
283
+ StmtKind :: Item ( ..) => None ,
284
+ StmtKind :: Local ( Local { init, pat, .. } ) => {
285
+ if let PatKind :: Binding ( _, hir_id, ..) = pat. kind {
286
+ init. map ( |init_expr| ( init_expr, Some ( hir_id) ) )
287
+ } else {
288
+ init. map ( |init_expr| ( init_expr, None ) )
289
+ }
290
+ } ,
291
+ }
292
+ }
293
+
216
294
struct UsedCountVisitor < ' a , ' tcx > {
217
295
cx : & ' a LateContext < ' tcx > ,
218
296
id : HirId ,
@@ -237,12 +315,60 @@ impl<'a, 'tcx> Visitor<'tcx> for UsedCountVisitor<'a, 'tcx> {
237
315
238
316
/// Detect the occurrences of calls to `iter` or `into_iter` for the
239
317
/// given identifier
240
- fn detect_iter_and_into_iters < ' tcx > ( block : & ' tcx Block < ' tcx > , id : HirId ) -> Option < Vec < IterFunction > > {
318
+ fn detect_iter_and_into_iters < ' tcx : ' a , ' a > (
319
+ block : & ' tcx Block < ' tcx > ,
320
+ id : HirId ,
321
+ cx : & ' a LateContext < ' tcx > ,
322
+ captured_ids : HirIdSet ,
323
+ ) -> Option < Vec < IterFunction > > {
241
324
let mut visitor = IterFunctionVisitor {
242
325
uses : Vec :: new ( ) ,
243
326
target : id,
244
327
seen_other : false ,
328
+ cx,
329
+ current_mutably_captured_ids : HirIdSet :: default ( ) ,
330
+ illegal_mutable_capture_ids : captured_ids,
331
+ hir_id_uses_map : FxHashMap :: default ( ) ,
332
+ current_statement_hir_id : None ,
245
333
} ;
246
334
visitor. visit_block ( block) ;
247
- if visitor. seen_other { None } else { Some ( visitor. uses ) }
335
+ if visitor. seen_other {
336
+ None
337
+ } else {
338
+ Some ( visitor. uses . into_iter ( ) . flatten ( ) . collect ( ) )
339
+ }
340
+ }
341
+
342
+ fn get_captured_ids ( cx : & LateContext < ' tcx > , ty : & ' _ TyS < ' _ > ) -> HirIdSet {
343
+ fn get_captured_ids_recursive ( cx : & LateContext < ' tcx > , ty : & ' _ TyS < ' _ > , set : & mut HirIdSet ) {
344
+ match ty. kind ( ) {
345
+ ty:: Adt ( _, generics) => {
346
+ for generic in * generics {
347
+ if let GenericArgKind :: Type ( ty) = generic. unpack ( ) {
348
+ get_captured_ids_recursive ( cx, ty, set) ;
349
+ }
350
+ }
351
+ } ,
352
+ ty:: Closure ( def_id, _) => {
353
+ let closure_hir_node = cx. tcx . hir ( ) . get_if_local ( * def_id) . unwrap ( ) ;
354
+ if let Node :: Expr ( closure_expr) = closure_hir_node {
355
+ can_move_expr_to_closure ( cx, closure_expr)
356
+ . unwrap ( )
357
+ . into_iter ( )
358
+ . for_each ( |( hir_id, capture_kind) | {
359
+ if matches ! ( capture_kind, CaptureKind :: Ref ( Mutability :: Mut ) ) {
360
+ set. insert ( hir_id) ;
361
+ }
362
+ } ) ;
363
+ }
364
+ } ,
365
+ _ => ( ) ,
366
+ }
367
+ }
368
+
369
+ let mut set = HirIdSet :: default ( ) ;
370
+
371
+ get_captured_ids_recursive ( cx, ty, & mut set) ;
372
+
373
+ set
248
374
}
0 commit comments