Skip to content

Simplify switch sources #136959

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions compiler/rustc_middle/src/mir/basic_blocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,22 @@ pub struct BasicBlocks<'tcx> {
// Typically 95%+ of basic blocks have 4 or fewer predecessors.
type Predecessors = IndexVec<BasicBlock, SmallVec<[BasicBlock; 4]>>;

type SwitchSources = FxHashMap<(BasicBlock, BasicBlock), SmallVec<[Option<u128>; 1]>>;
/// Each `(target, switch)` entry in the map contains a list of switch values
/// that lead to a `target` block from a `switch` block.
///
/// Note: this type is currently never instantiated, because it's only used for
/// `BasicBlocks::switch_sources`, which is only called by backwards analyses
/// that do `SwitchInt` handling, and we don't have any of those, not even in
/// tests. See #95120 and #94576.
type SwitchSources = FxHashMap<(BasicBlock, BasicBlock), SmallVec<[SwitchTargetValue; 1]>>;

#[derive(Debug, Clone, Copy)]
pub enum SwitchTargetValue {
// A normal switch value.
Normal(u128),
// The final "otherwise" fallback value.
Otherwise,
}

#[derive(Clone, Default, Debug)]
struct Cache {
Expand Down Expand Up @@ -70,8 +85,8 @@ impl<'tcx> BasicBlocks<'tcx> {
})
}

/// `switch_sources()[&(target, switch)]` returns a list of switch
/// values that lead to a `target` block from a `switch` block.
/// Returns info about switch values that lead from one block to another
/// block. See `SwitchSources`.
#[inline]
pub fn switch_sources(&self) -> &SwitchSources {
self.cache.switch_sources.get_or_init(|| {
Expand All @@ -82,9 +97,15 @@ impl<'tcx> BasicBlocks<'tcx> {
}) = &data.terminator
{
for (value, target) in targets.iter() {
switch_sources.entry((target, bb)).or_default().push(Some(value));
switch_sources
.entry((target, bb))
.or_default()
.push(SwitchTargetValue::Normal(value));
}
switch_sources.entry((targets.otherwise(), bb)).or_default().push(None);
switch_sources
.entry((targets.otherwise(), bb))
.or_default()
.push(SwitchTargetValue::Otherwise);
}
}
switch_sources
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::fmt::{self, Debug, Formatter};
use std::ops::{Index, IndexMut};
use std::{iter, mem};

pub use basic_blocks::BasicBlocks;
pub use basic_blocks::{BasicBlocks, SwitchTargetValue};
use either::Either;
use polonius_engine::Atom;
use rustc_abi::{FieldIdx, VariantIdx};
Expand Down
30 changes: 19 additions & 11 deletions compiler/rustc_middle/src/mir/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1015,22 +1015,30 @@ impl TerminatorKind<'_> {

#[derive(Debug, Clone, TyEncodable, TyDecodable, Hash, HashStable, PartialEq)]
pub struct SwitchTargets {
/// Possible values. The locations to branch to in each case
/// are found in the corresponding indices from the `targets` vector.
/// Possible values. For each value, the location to branch to is found in
/// the corresponding element in the `targets` vector.
pub(super) values: SmallVec<[Pu128; 1]>,

/// Possible branch sites. The last element of this vector is used
/// for the otherwise branch, so targets.len() == values.len() + 1
/// should hold.
/// Possible branch targets. The last element of this vector is used for
/// the "otherwise" branch, so `targets.len() == values.len() + 1` always
/// holds.
//
// This invariant is quite non-obvious and also could be improved.
// One way to make this invariant is to have something like this instead:
// Note: This invariant is non-obvious and easy to violate. This would be a
// more rigorous representation:
//
// branches: Vec<(ConstInt, BasicBlock)>,
// otherwise: Option<BasicBlock> // exhaustive if None
// normal: SmallVec<[(Pu128, BasicBlock); 1]>,
// otherwise: BasicBlock,
//
// However we’ve decided to keep this as-is until we figure a case
// where some other approach seems to be strictly better than other.
// But it's important to have the targets in a sliceable type, because
// target slices show up elsewhere. E.g. `TerminatorKind::InlineAsm` has a
// boxed slice, and `TerminatorKind::FalseEdge` has a single target that
// can be converted to a slice with `slice::from_ref`.
//
// Why does this matter? In functions like `TerminatorKind::successors` we
// return `impl Iterator` and a non-slice-of-targets representation here
// causes problems because multiple different concrete iterator types would
// be involved and we would need a boxed trait object, which requires an
// allocation, which is expensive if done frequently.
pub(super) targets: SmallVec<[BasicBlock; 2]>,
}

Expand Down
27 changes: 11 additions & 16 deletions compiler/rustc_mir_dataflow/src/framework/direction.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::ops::RangeInclusive;

use rustc_middle::mir::{self, BasicBlock, CallReturnPlaces, Location, TerminatorEdges};
use rustc_middle::mir::{
self, BasicBlock, CallReturnPlaces, Location, SwitchTargetValue, TerminatorEdges,
};

use super::visitor::ResultsVisitor;
use super::{Analysis, Effect, EffectIndex, Results, SwitchIntTarget};
use super::{Analysis, Effect, EffectIndex, Results};

pub trait Direction {
const IS_FORWARD: bool;
Expand Down Expand Up @@ -112,14 +114,10 @@ impl Direction for Backward {

mir::TerminatorKind::SwitchInt { targets: _, ref discr } => {
if let Some(mut data) = analysis.get_switch_int_data(block, discr) {
let values = &body.basic_blocks.switch_sources()[&(block, pred)];
let targets =
values.iter().map(|&value| SwitchIntTarget { value, target: block });

let mut tmp = analysis.bottom_value(body);
for target in targets {
tmp.clone_from(&exit_state);
analysis.apply_switch_int_edge_effect(&mut data, &mut tmp, target);
for &value in &body.basic_blocks.switch_sources()[&(block, pred)] {
tmp.clone_from(exit_state);
analysis.apply_switch_int_edge_effect(&mut data, &mut tmp, value);
propagate(pred, &tmp);
}
} else {
Expand Down Expand Up @@ -292,12 +290,9 @@ impl Direction for Forward {
if let Some(mut data) = analysis.get_switch_int_data(block, discr) {
let mut tmp = analysis.bottom_value(body);
for (value, target) in targets.iter() {
tmp.clone_from(&exit_state);
analysis.apply_switch_int_edge_effect(
&mut data,
&mut tmp,
SwitchIntTarget { value: Some(value), target },
);
tmp.clone_from(exit_state);
let value = SwitchTargetValue::Normal(value);
analysis.apply_switch_int_edge_effect(&mut data, &mut tmp, value);
propagate(target, &tmp);
}

Expand All @@ -308,7 +303,7 @@ impl Direction for Forward {
analysis.apply_switch_int_edge_effect(
&mut data,
exit_state,
SwitchIntTarget { value: None, target: otherwise },
SwitchTargetValue::Otherwise,
);
propagate(otherwise, exit_state);
} else {
Expand Down
11 changes: 4 additions & 7 deletions compiler/rustc_mir_dataflow/src/framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ use rustc_data_structures::work_queue::WorkQueue;
use rustc_index::bit_set::{DenseBitSet, MixedBitSet};
use rustc_index::{Idx, IndexVec};
use rustc_middle::bug;
use rustc_middle::mir::{self, BasicBlock, CallReturnPlaces, Location, TerminatorEdges, traversal};
use rustc_middle::mir::{
self, BasicBlock, CallReturnPlaces, Location, SwitchTargetValue, TerminatorEdges, traversal,
};
use rustc_middle::ty::TyCtxt;
use tracing::error;

Expand Down Expand Up @@ -220,7 +222,7 @@ pub trait Analysis<'tcx> {
&mut self,
_data: &mut Self::SwitchIntData,
_state: &mut Self::Domain,
_edge: SwitchIntTarget,
_value: SwitchTargetValue,
) {
unreachable!();
}
Expand Down Expand Up @@ -430,10 +432,5 @@ impl EffectIndex {
}
}

pub struct SwitchIntTarget {
pub value: Option<u128>,
pub target: BasicBlock,
}

#[cfg(test)]
mod tests;
13 changes: 7 additions & 6 deletions compiler/rustc_mir_dataflow/src/impls/initialized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ use rustc_abi::VariantIdx;
use rustc_index::Idx;
use rustc_index::bit_set::{DenseBitSet, MixedBitSet};
use rustc_middle::bug;
use rustc_middle::mir::{self, Body, CallReturnPlaces, Location, TerminatorEdges};
use rustc_middle::mir::{
self, Body, CallReturnPlaces, Location, SwitchTargetValue, TerminatorEdges,
};
use rustc_middle::ty::util::Discr;
use rustc_middle::ty::{self, TyCtxt};
use tracing::{debug, instrument};

use crate::drop_flag_effects::DropFlagState;
use crate::framework::SwitchIntTarget;
use crate::move_paths::{HasMoveData, InitIndex, InitKind, LookupResult, MoveData, MovePathIndex};
use crate::{
Analysis, GenKill, MaybeReachable, drop_flag_effects, drop_flag_effects_for_function_entry,
Expand Down Expand Up @@ -422,9 +423,9 @@ impl<'tcx> Analysis<'tcx> for MaybeInitializedPlaces<'_, 'tcx> {
&mut self,
data: &mut Self::SwitchIntData,
state: &mut Self::Domain,
edge: SwitchIntTarget,
value: SwitchTargetValue,
) {
if let Some(value) = edge.value {
if let SwitchTargetValue::Normal(value) = value {
// Kill all move paths that correspond to variants we know to be inactive along this
// particular outgoing edge of a `SwitchInt`.
drop_flag_effects::on_all_inactive_variants(
Expand Down Expand Up @@ -535,9 +536,9 @@ impl<'tcx> Analysis<'tcx> for MaybeUninitializedPlaces<'_, 'tcx> {
&mut self,
data: &mut Self::SwitchIntData,
state: &mut Self::Domain,
edge: SwitchIntTarget,
value: SwitchTargetValue,
) {
if let Some(value) = edge.value {
if let SwitchTargetValue::Normal(value) = value {
// Mark all move paths that correspond to variants other than this one as maybe
// uninitialized (in reality, they are *definitely* uninitialized).
drop_flag_effects::on_all_inactive_variants(
Expand Down
Loading