Skip to content

Use more explicit and reliable ptr select in sort impls #136690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions library/core/src/slice/sort/shared/smallsort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ unsafe fn swap_if_less<T, F>(v_base: *mut T, a_pos: usize, b_pos: usize, is_less
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: the caller must guarantee that `a` and `b` each added to `v_base` yield valid
// SAFETY: the caller must guarantee that `a_pos` and `b_pos` each added to `v_base` yield valid
// pointers into `v_base`, and are properly aligned, and part of the same allocation.
unsafe {
let v_a = v_base.add(a_pos);
Expand All @@ -404,16 +404,16 @@ where
// The equivalent code with a branch would be:
//
// if should_swap {
// ptr::swap(left, right, 1);
// ptr::swap(v_a, v_b, 1);
// }

// The goal is to generate cmov instructions here.
let left_swap = if should_swap { v_b } else { v_a };
let right_swap = if should_swap { v_a } else { v_b };
let v_a_swap = should_swap.select_unpredictable(v_b, v_a);
let v_b_swap = should_swap.select_unpredictable(v_a, v_b);

let right_swap_tmp = ManuallyDrop::new(ptr::read(right_swap));
ptr::copy(left_swap, v_a, 1);
ptr::copy_nonoverlapping(&*right_swap_tmp, v_b, 1);
let v_b_swap_tmp = ManuallyDrop::new(ptr::read(v_b_swap));
ptr::copy(v_a_swap, v_a, 1);
ptr::copy_nonoverlapping(&*v_b_swap_tmp, v_b, 1);
}
}

Expand Down Expand Up @@ -640,26 +640,21 @@ pub unsafe fn sort4_stable<T, F: FnMut(&T, &T) -> bool>(
// 1, 1 | c b a d
let c3 = is_less(&*c, &*a);
let c4 = is_less(&*d, &*b);
let min = select(c3, c, a);
let max = select(c4, b, d);
let unknown_left = select(c3, a, select(c4, c, b));
let unknown_right = select(c4, d, select(c3, b, c));
let min = c3.select_unpredictable(c, a);
let max = c4.select_unpredictable(b, d);
let unknown_left = c3.select_unpredictable(a, c4.select_unpredictable(c, b));
let unknown_right = c4.select_unpredictable(d, c3.select_unpredictable(b, c));

// Sort the last two unknown elements.
let c5 = is_less(&*unknown_right, &*unknown_left);
let lo = select(c5, unknown_right, unknown_left);
let hi = select(c5, unknown_left, unknown_right);
let lo = c5.select_unpredictable(unknown_right, unknown_left);
let hi = c5.select_unpredictable(unknown_left, unknown_right);

ptr::copy_nonoverlapping(min, dst, 1);
ptr::copy_nonoverlapping(lo, dst.add(1), 1);
ptr::copy_nonoverlapping(hi, dst.add(2), 1);
ptr::copy_nonoverlapping(max, dst.add(3), 1);
}

#[inline(always)]
fn select<T>(cond: bool, if_true: *const T, if_false: *const T) -> *const T {
if cond { if_true } else { if_false }
}
}

/// SAFETY: The caller MUST guarantee that `v_base` is valid for 8 reads and
Expand Down
Loading