From c36e8fcc3c128e31eaa643904c9b8b33d9a5c1a6 Mon Sep 17 00:00:00 2001 From: Yotam Ofek Date: Fri, 11 Apr 2025 14:26:26 +0000 Subject: [PATCH] In `rustc_mir_tranform`, iterate over index newtypes instead of ints --- compiler/rustc_index_macros/src/newtype.rs | 7 +++ compiler/rustc_mir_transform/src/coroutine.rs | 53 ++++++++----------- .../src/early_otherwise_branch.rs | 3 +- .../rustc_mir_transform/src/match_branches.rs | 17 +++--- .../src/multiple_return_terminators.rs | 14 +++-- compiler/rustc_mir_transform/src/validate.rs | 5 +- 6 files changed, 46 insertions(+), 53 deletions(-) diff --git a/compiler/rustc_index_macros/src/newtype.rs b/compiler/rustc_index_macros/src/newtype.rs index f0b58eabbff9a..eedbe630cf2c4 100644 --- a/compiler/rustc_index_macros/src/newtype.rs +++ b/compiler/rustc_index_macros/src/newtype.rs @@ -257,6 +257,13 @@ impl Parse for Newtype { } } + impl std::ops::AddAssign for #name { + #[inline] + fn add_assign(&mut self, other: usize) { + *self = *self + other; + } + } + impl rustc_index::Idx for #name { #[inline] fn new(value: usize) -> Self { diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index 04d96f117072f..80c729d66b1ec 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -547,7 +547,7 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None); - for bb in START_BLOCK..body.basic_blocks.next_index() { + for bb in body.basic_blocks.indices() { let bb_data = &body[bb]; if bb_data.is_cleanup { continue; @@ -556,11 +556,11 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { match &bb_data.terminator().kind { TerminatorKind::Call { func, .. } => { let func_ty = func.ty(body, tcx); - if let ty::FnDef(def_id, _) = *func_ty.kind() { - if def_id == get_context_def_id { - let local = eliminate_get_context_call(&mut body[bb]); - replace_resume_ty_local(tcx, body, local, context_mut_ref); - } + if let ty::FnDef(def_id, _) = *func_ty.kind() + && def_id == get_context_def_id + { + let local = eliminate_get_context_call(&mut body[bb]); + replace_resume_ty_local(tcx, body, local, context_mut_ref); } } TerminatorKind::Yield { resume_arg, .. } => { @@ -1057,7 +1057,7 @@ fn insert_switch<'tcx>( let blocks = body.basic_blocks_mut().iter_mut(); for target in blocks.flat_map(|b| b.terminator_mut().successors_mut()) { - *target = BasicBlock::new(target.index() + 1); + *target += 1; } } @@ -1209,14 +1209,8 @@ fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, typing_env: ty::Typing } // If there's a return terminator the function may return. - for block in body.basic_blocks.iter() { - if let TerminatorKind::Return = block.terminator().kind { - return true; - } - } - + body.basic_blocks.iter().any(|block| matches!(block.terminator().kind, TerminatorKind::Return)) // Otherwise the function can't return. - false } fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool { @@ -1293,12 +1287,12 @@ fn create_coroutine_resume_function<'tcx>( kind: TerminatorKind::Goto { target: poison_block }, }; } - } else if !block.is_cleanup { + } else if !block.is_cleanup // Any terminators that *can* unwind but don't have an unwind target set are also // pointed at our poisoning block (unless they're part of the cleanup path). - if let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut() { - *unwind = UnwindAction::Cleanup(poison_block); - } + && let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut() + { + *unwind = UnwindAction::Cleanup(poison_block); } } } @@ -1340,12 +1334,14 @@ fn create_coroutine_resume_function<'tcx>( make_coroutine_state_argument_indirect(tcx, body); match transform.coroutine_kind { + CoroutineKind::Coroutine(_) + | CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) => + { + make_coroutine_state_argument_pinned(tcx, body); + } // Iterator::next doesn't accept a pinned argument, // unlike for all other coroutine kinds. CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {} - _ => { - make_coroutine_state_argument_pinned(tcx, body); - } } // Make sure we remove dead blocks to remove @@ -1408,8 +1404,7 @@ fn create_cases<'tcx>( let mut statements = Vec::new(); // Create StorageLive instructions for locals with live storage - for i in 0..(body.local_decls.len()) { - let l = Local::new(i); + for l in body.local_decls.indices() { let needs_storage_live = point.storage_liveness.contains(l) && !transform.remap.contains(l) && !transform.always_live_locals.contains(l); @@ -1535,15 +1530,10 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform { let coroutine_kind = body.coroutine_kind().unwrap(); // Get the discriminant type and args which typeck computed - let (discr_ty, movable) = match *coroutine_ty.kind() { - ty::Coroutine(_, args) => { - let args = args.as_coroutine(); - (args.discr_ty(tcx), coroutine_kind.movability() == hir::Movability::Movable) - } - _ => { - tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}")); - } + let ty::Coroutine(_, args) = coroutine_ty.kind() else { + tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}")); }; + let discr_ty = args.as_coroutine().discr_ty(tcx); let new_ret_ty = match coroutine_kind { CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => { @@ -1610,6 +1600,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform { let always_live_locals = always_storage_live_locals(body); + let movable = coroutine_kind.movability() == hir::Movability::Movable; let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); diff --git a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs index 57f7893be1b8c..d49f5d9f9c385 100644 --- a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs +++ b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs @@ -103,9 +103,8 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch { let mut should_cleanup = false; // Also consider newly generated bbs in the same pass - for i in 0..body.basic_blocks.len() { + for parent in body.basic_blocks.indices() { let bbs = &*body.basic_blocks; - let parent = BasicBlock::from_usize(i); let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue }; trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}"); diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index 0d9d0368d3729..5059837328e24 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -20,13 +20,11 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let typing_env = body.typing_env(tcx); let mut should_cleanup = false; - for i in 0..body.basic_blocks.len() { - let bbs = &*body.basic_blocks; - let bb_idx = BasicBlock::from_usize(i); - match bbs[bb_idx].terminator().kind { + for bb_idx in body.basic_blocks.indices() { + match &body.basic_blocks[bb_idx].terminator().kind { TerminatorKind::SwitchInt { - discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)), - ref targets, + discr: Operand::Copy(_) | Operand::Move(_), + targets, .. // We require that the possible target blocks don't contain this block. } if !targets.all_targets().contains(&bb_idx) => {} @@ -66,9 +64,10 @@ trait SimplifyMatch<'tcx> { typing_env: ty::TypingEnv<'tcx>, ) -> Option<()> { let bbs = &body.basic_blocks; - let (discr, targets) = match bbs[switch_bb_idx].terminator().kind { - TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets), - _ => unreachable!(), + let TerminatorKind::SwitchInt { discr, targets, .. } = + &bbs[switch_bb_idx].terminator().kind + else { + unreachable!(); }; let discr_ty = discr.ty(body.local_decls(), tcx); diff --git a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs index c63bfdcee8559..f59b849e85c62 100644 --- a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs +++ b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs @@ -18,19 +18,17 @@ impl<'tcx> crate::MirPass<'tcx> for MultipleReturnTerminators { // find basic blocks with no statement and a return terminator let mut bbs_simple_returns = DenseBitSet::new_empty(body.basic_blocks.len()); let bbs = body.basic_blocks_mut(); - for idx in bbs.indices() { - if bbs[idx].statements.is_empty() - && bbs[idx].terminator().kind == TerminatorKind::Return - { + for (idx, bb) in bbs.iter_enumerated() { + if bb.statements.is_empty() && bb.terminator().kind == TerminatorKind::Return { bbs_simple_returns.insert(idx); } } for bb in bbs { - if let TerminatorKind::Goto { target } = bb.terminator().kind { - if bbs_simple_returns.contains(target) { - bb.terminator_mut().kind = TerminatorKind::Return; - } + if let TerminatorKind::Goto { target } = bb.terminator().kind + && bbs_simple_returns.contains(target) + { + bb.terminator_mut().kind = TerminatorKind::Return; } } diff --git a/compiler/rustc_mir_transform/src/validate.rs b/compiler/rustc_mir_transform/src/validate.rs index e7930f0a1e3f6..66fe3ef4141f5 100644 --- a/compiler/rustc_mir_transform/src/validate.rs +++ b/compiler/rustc_mir_transform/src/validate.rs @@ -221,12 +221,11 @@ impl<'a, 'tcx> CfgChecker<'a, 'tcx> { // Check for cycles let mut stack = FxHashSet::default(); - for i in 0..parent.len() { - let mut bb = BasicBlock::from_usize(i); + for (mut bb, parent) in parent.iter_enumerated_mut() { stack.clear(); stack.insert(bb); loop { - let Some(parent) = parent[bb].take() else { break }; + let Some(parent) = parent.take() else { break }; let no_cycle = stack.insert(parent); if !no_cycle { self.fail(