Regarding isqrt performance #137786

Open
leonardo-m opened this issue Feb 28, 2025 · 16 comments
Labels
C-discussion Category: Discussion or questions that doesn't represent real issues. T-libs Relevant to the library team, which will review and decide on the PR/issue.

Comments

@leonardo-m

This is experimental code; it can be compiled with either of two implementations of the integer square root:

#![feature(cfg_overflow_checks, stmt_expr_attributes)]
#![allow(unused_macros)]

#[cfg(overflow_checks)]
macro_rules! fto { ($e:expr, $t:ident) => (($e) as $t) }

#[cfg(not(overflow_checks))]
macro_rules! fto { ($e:expr, $t:ident) => (#[allow(unused_unsafe)] unsafe { $e.to_int_unchecked::<$t>() }) }

macro_rules! isqrt { ($e:expr, $t:ident) => ( fto!{(($e) as f64).sqrt(), $t} ) }

fn main() {
    let mut tot: u64 = 0;
    for x in (1_u64 .. 1_000_000_000).step_by(3) {
        tot += isqrt!(x, u64); // Version A.  1.04 seconds.
        //tot += x.isqrt(); // Version B.       4.94 seconds.
    }
    println!("{tot}");
}

As you can see, the floating-point-based isqrt implementation is almost five times faster on my PC (using rustc 1.87.0-nightly). The std lib isqrt has the upside of being const, so I can use it for const generics calculations and similar situations. When I need to compute only one or a few isqrt values, this performance difference isn't important, and I use the std one. But when I need to compute a lot of them, I consider faster alternatives.

While I don't propose replacing the std library isqrt with the code I've shown here, I'd like the std lib implementers to express an opinion on using floating-point sqrt plus an int cast in some vetted and safe cases.
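
To make "vetted and safe" concrete: the benchmark above only uses inputs below 10^9, where every value converts to f64 exactly. Over the full u64 range the plain cast can be off by one, because `x as f64` rounds once x exceeds 2^53. The following is a minimal sketch of a float estimate followed by an exact integer fix-up; `isqrt_f64_corrected` is a hypothetical name used only for illustration, not something from std.

```rust
/// Floor of the square root of `x`: an f64 estimate plus an exact integer
/// fix-up. `isqrt_f64_corrected` is a hypothetical name for this sketch.
fn isqrt_f64_corrected(x: u64) -> u64 {
    // `x as f64` rounds for x > 2^53, so the estimate may be slightly off.
    let mut r = (x as f64).sqrt() as u64;
    // Step down while r * r exceeds x (`None` means r * r overflowed u64).
    while r.checked_mul(r).map_or(true, |sq| sq > x) {
        r -= 1;
    }
    // Step up while (r + 1)^2 still does not exceed x.
    while (r + 1).checked_mul(r + 1).map_or(false, |sq| sq <= x) {
        r += 1;
    }
    r
}

fn main() {
    // u64::MAX rounds up to 2^64 as an f64, so the bare cast overshoots.
    let naive = (u64::MAX as f64).sqrt() as u64;
    println!("naive:     {naive}"); // 4294967296
    println!("corrected: {}", isqrt_f64_corrected(u64::MAX)); // 4294967295
}
```

The fix-up loops run at most a step or two in practice, so most of the speed of the float path should be kept, though I haven't benchmarked this variant.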

@leonardo-m leonardo-m added the C-bug Category: This is a bug. label Feb 28, 2025
@rustbot rustbot added the needs-triage This issue may need triage. Remove it if it has been sufficiently triaged. label Feb 28, 2025
@bjorn3 bjorn3 added C-discussion Category: Discussion or questions that doesn't represent real issues. and removed C-bug Category: This is a bug. labels Feb 28, 2025
@jieyouxu jieyouxu added T-libs Relevant to the library team, which will review and decide on the PR/issue. and removed needs-triage This issue may need triage. Remove it if it has been sufficiently triaged. labels Feb 28, 2025
@CatsAreFluffy

This implementation of u32::isqrt is float-free and significantly faster than the stdlib version:

// SQRTS[i<256] = (i << 8).isqrt() as u8
const fn new_isqrt(x: u32) -> u32 {
    if x < 256 {
        return SQRTS[x as usize] as u32 >> 4;
    }
    let idx = x >> ((25 - x.leading_zeros()) & !1);
    // SAFETY: If x has y leading zeros, the shift count is either 24 - y or
    // 25 - y. Thus idx has either 24 or 25 leading zeroes, and in particular
    // it's less than 256.
    unsafe { std::hint::assert_unchecked(idx < 256) };
    let approx1 = SQRTS[idx as usize] as u32;
    // SAFETY: Every element of SQRTS is at least 16 except element 0, and
    // idx is positive so that cannot be selected.
    unsafe { std::hint::assert_unchecked(approx1 >= 16) };
    let mut approx = (approx1 + 1) << ((25 - x.leading_zeros()) / 2) >> 4;
    approx = (approx + x / approx) / 2;
    if approx * approx > x {
        approx -= 1;
    }
    approx
}

Godbolt

On my M3 Max laptop, this implementation takes ~6.7 seconds to evaluate for all 2^32 inputs, while u32::isqrt takes ~15.7 seconds, and |x| (x as f64).sqrt() as u32 takes ~3.0 seconds. (I've included the test harness I used to get those numbers in the Godbolt if you want to try it for yourself.)
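
The SAFETY comments in new_isqrt depend on the table index landing in 64..256 for every x >= 256. Here is a small sketch that prints the index at the smallest and largest input of each bit length; `table_index` is a hypothetical helper that just isolates the shift expression from the code above.

```rust
/// The lookup index used by `new_isqrt` for inputs >= 256.
/// `table_index` is a hypothetical helper name for this sketch.
fn table_index(x: u32) -> u32 {
    assert!(x >= 256, "smaller inputs take the direct table path");
    // Clearing the low bit keeps the shift even, so its half is an integer.
    x >> ((25 - x.leading_zeros()) & !1)
}

fn main() {
    for bit in 8..32 {
        let lo = 1_u32 << bit; // smallest value whose top set bit is `bit`
        let hi = lo | (lo - 1); // largest value whose top set bit is `bit`
        println!("bit {bit:2}: idx in {}..={}", table_index(lo), table_index(hi));
    }
}
```

When the top set bit is even the index falls in 64..=127, and when it is odd it falls in 128..=255, which matches the "24 or 25 leading zeroes" reasoning.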

@cuviper
Member

cuviper commented Mar 4, 2025

With my Ryzen 7 9800X3D on Linux, that's significantly slower at first:

New isqrt:
Time = 28.09259456s
f64 isqrt:
Time = 6.522072657s
Stdlib isqrt:
Time = 18.097873232s

That turns around with -Ctarget-cpu=native (zen5):

New isqrt:
Time = 6.881527916s
f64 isqrt:
Time = 6.457018049s
Stdlib isqrt:
Time = 17.119138915s

@ChaiTRex
Contributor

ChaiTRex commented Mar 4, 2025

To avoid the SQRTS generation relying on the isqrt method (which may become circular if we replace more than the 32-bit isqrt methods), the following code can be used (checked to be identical with the code in the Godbolt above):

/// Fixed-point square roots of 0..=255, with 4 bits integer part and
/// 4 bits fractional part.
const FIXED_POINT_SQRTS: [u8; 256] = {
    let mut result = [0; 256];

    let mut sqrt = 0_u32;
    let mut i = 0_u32;
    while i < 256 {
        while sqrt * sqrt <= (i << 8) {
            sqrt += 1;
        }
        sqrt -= 1;

        result[i as usize] = sqrt as u8;
        i += 1;
    }

    result
};
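
For anyone who wants to try this without the Godbolt link, here is a self-contained sketch that pairs the table above with new_isqrt from earlier in the thread. It differs from the posted code in one respect, which is my assumption about what is acceptable for a demo: it uses ordinary bounds-checked indexing instead of the `assert_unchecked` hints. The check in `main` is a sampled one, not the exhaustive harness.

```rust
/// FIXED_POINT_SQRTS[i] = floor(sqrt(i * 256)), built without calling isqrt.
const FIXED_POINT_SQRTS: [u8; 256] = {
    let mut result = [0; 256];
    let mut sqrt = 0_u32;
    let mut i = 0_u32;
    while i < 256 {
        // `sqrt` only ever moves forward, so the whole table is one linear scan.
        while sqrt * sqrt <= (i << 8) {
            sqrt += 1;
        }
        sqrt -= 1;
        result[i as usize] = sqrt as u8;
        i += 1;
    }
    result
};

/// Same arithmetic as `new_isqrt` above, with bounds-checked indexing.
const fn new_isqrt(x: u32) -> u32 {
    if x < 256 {
        return FIXED_POINT_SQRTS[x as usize] as u32 >> 4;
    }
    let bits = 25 - x.leading_zeros();
    let idx = x >> (bits & !1);
    let approx1 = FIXED_POINT_SQRTS[idx as usize] as u32;
    // Table estimate scaled back up, then one Newton step and one correction.
    let mut approx = ((approx1 + 1) << (bits / 2)) >> 4;
    approx = (approx + x / approx) / 2;
    if approx * approx > x {
        approx -= 1;
    }
    approx
}

fn main() {
    // Sampled check of r^2 <= x < (r + 1)^2, done in u64 to avoid overflow.
    let mut x: u64 = 0;
    while x <= u32::MAX as u64 {
        let r = new_isqrt(x as u32) as u64;
        assert!(r * r <= x && x < (r + 1) * (r + 1), "mismatch at {x}");
        x += 65_521; // arbitrary prime stride
    }
    println!("sampled check passed");
}
```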

@CatsAreFluffy

CatsAreFluffy commented Mar 4, 2025

> With my Ryzen 7 9800X3D on Linux, that's significantly slower at first:

This is caused by an issue with how leading_zeros is compiled on x86 and x86-64 targets not supporting BMI1. Specifically, the bsr instruction, which is used to implement leading_zeros, has a dependency on its output register. For new_isqrt, the register LLVM picks for that is also used in the validation code, so iterations of the loop are almost completely serialized. For u32::isqrt, LLVM picks a register which isn't used for anything else, so there's approximately no speed penalty. You could work around this by using inline assembly that properly breaks the dependency instead of leading_zeros, although I'm not sure if it would matter much outside of simple benchmarks like this.

@tgross35
Contributor

tgross35 commented Mar 4, 2025

> This is caused by an issue with how leading_zeros is compiled on x86 and x86-64 targets not supporting BMI1 (specifically, the bsr instruction, which is used to implement leading_zeros, has a dependency on its output register, and for new_isqrt, the register LLVM picks for that is also used in the validation code, so iterations of the loop are almost completely serialized, while for u32::isqrt, LLVM picks a register which isn't used for anything else, so there's approximately no speed penalty).

Does LLVM have enough information to make a better register selection, in theory? If so, this would be worth opening an issue (if one doesn't already exist).

@CatsAreFluffy

I'm not sure how feasible it is to solve that by picking a better output register, but it's also solvable by initializing the destination register immediately before the bsr so the output dependency can't get stuck on anything else, and I've opened llvm/llvm-project#129659 for that.

@CatsAreFluffy

CatsAreFluffy commented Mar 4, 2025

I've made a slightly faster isqrt implementation.

// SQRTS[i<256] = (i * 256).isqrt()
// RECIPS[i<192] = (1 << 38).div_ceil(SQRTS[i + 64] + 1)
fn newer_isqrt2(x: u32) -> u32 {
    if x < 256 {
        return SQRTS[x as usize] as u32 >> 4;
    }
    let idx = x >> ((25 - x.leading_zeros()) & !1);
    // SAFETY: If x has y leading zeros, the shift count is either 24 - y or
    // 25 - y. Thus idx has either 24 or 25 leading zeroes, and in particular
    // it's in 64..256.
    unsafe { std::hint::assert_unchecked(64 <= idx) };
    unsafe { std::hint::assert_unchecked(idx < 256) };
    let approx1 = SQRTS[idx as usize] as u32 + 1;
    let approx2 = approx1 << ((25 - x.leading_zeros()) / 2);
    let divmult = RECIPS[idx as usize - 64] as u64;
    // Approximately `x * 256 / approx2`, i.e. about `16 * sqrt(x)`.
    let approxr = ((x as u64 * divmult) >> ((25 - x.leading_zeros()) / 2 + 30)) as u32;
    let mut approx3 = (approx2 + approxr) >> 5;
    if approx3 * approx3 > x {
        approx3 -= 1;
    }
    approx3
}

Godbolt. (newer_isqrt2 is the faster one, newer_isqrt is similar but slower than new_isqrt.)

On my laptop, I get:

Newer isqrt 2:
Time = 6.589210709s
Newer isqrt:
Time = 7.007925959s
New isqrt:
Time = 6.910427791s
f64 isqrt:
Time = 4.150698459s
Stdlib isqrt:
Time = 14.867240041s

@CatsAreFluffy

Here's a version which is (probably) a bit faster on 32-bit targets.

// SQRTS[i<256] = (i * 256).isqrt()
// RECIPS[i<192] = (1 << 39).div_ceil(SQRTS[i + 64] + 1)
fn new_isqrt_32bit(x: u32) -> u32 {
    if x < 256 {
        return SQRTS[x as usize] as u32 >> 4;
    }
    let idx = x >> ((25 - x.leading_zeros()) & !1);
    // SAFETY: If x has y leading zeros, the shift count is either 24 - y or
    // 25 - y. Thus idx has either 24 or 25 leading zeroes, and in particular
    // it's in 64..256.
    unsafe { std::hint::assert_unchecked(64 <= idx) };
    unsafe { std::hint::assert_unchecked(idx < 256) };
    let approx1 = SQRTS[idx as usize] as u32 + 1;
    let approx2 = approx1 << ((25 - x.leading_zeros()) / 2);
    let divmult = RECIPS[idx as usize - 64] as u64;
    // Approximately `x * 256 / approx2`, i.e. about `16 * sqrt(x)`.
    let approxr = (((x as u64 * divmult) >> 32) as u32) >> ((25 - x.leading_zeros()) / 2 - 1);
    let mut approx3 = (approx2 + approxr) >> 5;
    if approx3 * approx3 > x {
        approx3 -= 1;
    }
    approx3
}

Godbolt

@leonardo-m
Author

> Here's a version which is (probably) a bit faster on 32-bit targets.

The std lib could use a const icbrt as well :-) I've implemented it twice in my Rust coding, and one of them was buggy.

@ChaiTRex
Contributor

I recently found out about the #[cfg(optimize_for_size)] attribute (#125011) that's usable in the standard library, which lets standard library developers provide two implementations of a method so that people compiling the standard library can choose between code optimized for speed and code that will run on (perhaps very) memory-constrained devices.

That issue says that some microcontrollers have only 16 KiB of total memory, and since SQRTS and RECIPS together take a total of 1024 bytes, I think that we should consider having two implementations in the standard library. Are there any pretty fast implementations that have a small code + table size?
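
As a baseline for the size-constrained case, one table-free option is the classic bit-by-bit (digit-by-digit) square root. The sketch below illustrates that well-known technique; it is not code proposed in this thread, `isqrt_bitwise` is a hypothetical name, and I make no speed claim for it.

```rust
/// Table-free integer square root using the bit-by-bit method.
/// `isqrt_bitwise` is a hypothetical name for this sketch.
const fn isqrt_bitwise(mut x: u32) -> u32 {
    let mut result = 0_u32;
    // Start at the highest power of four that is <= x.
    let mut bit = 1_u32 << 30;
    while bit > x {
        bit >>= 2;
    }
    // Decide one result bit per iteration, from most to least significant.
    while bit != 0 {
        if x >= result + bit {
            x -= result + bit;
            result = (result >> 1) + bit;
        } else {
            result >>= 1;
        }
        bit >>= 2;
    }
    result
}

fn main() {
    for x in [0_u32, 1, 15, 16, 1_000_000, u32::MAX] {
        println!("isqrt_bitwise({x}) = {}", isqrt_bitwise(x));
    }
}
```

It needs no lookup table and no division, at the cost of up to 16 loop iterations per call.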

@CatsAreFluffy

new_isqrt from my earlier posts uses only 256 bytes of lookup table.

@CatsAreFluffy

This StackOverflow question has some nice implementations.

@ChaiTRex
Contributor

> This StackOverflow question has some nice implementations.

We need to be careful because if we start with some of that code and modify it (such as porting it to Rust and/or improving its speed), we'd be making a derivative work and be bound by its license. I believe that we try to keep everything in the Rust standard library under MIT or Apache 2.0, and I'm not sure whether Stack Overflow's chosen CC-BY-SA licenses are compatible with that.

One way around it would be if we decided which implementation(s) we wanted to work with, and then we got permission from the answerers in your linked post to license it under MIT and Apache 2.0. It appears both of the answerers in your linked post have been active on Stack Overflow in the past week. We'd need to follow the chain backwards in case they modified someone else's code (and that code wasn't under MIT or Apache 2.0) and so forth.

@ChaiTRex
Contributor

It appears that the author of the longer answer there also was the original author of Python's math.isqrt and some improvements to it, so getting permission to use it might be simpler than I thought.

@leonardo-m
Author

> The std lib could use a const icbrt as well

A basic icbrt version:

// Integer cube root for u32 values.
// (Don't use this to create a u64 version).
const fn icbrt_u32(x: u32) -> u32 {
    let mut x = x as u64;
    let mut y = 0;
    let mut s: i32 = 63;
    while s >= 0 {
        y += y;
        let b = 3 * y * (y + 1) + 1;
        if (x >> s) >= b {
            x -= b << s;
            y += 1;
        }
        s -= 3;
    }
    y as _
}

fn main() {
    for x in 0 ..= u32::MAX {
        if x % (1 << 24) == 0 {
            println!("{x}");
        }
        let c1 = icbrt_u32(x);
        let c2 = f64::from(x).cbrt() as u32;
        if c1 != c2 {
            println!("{x} {c1} {c2}");
        }
    }
}
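
The loop above walks all 2^32 inputs and compares against the f64 cbrt. A much quicker float-free check, sketched below, looks only at the places where the result can change: each perfect cube and the value just before it. The function body is copied from above with an explicit type on `y`; the boundary-only test strategy is my own addition.

```rust
/// Integer cube root for u32 values (same algorithm as above).
const fn icbrt_u32(x: u32) -> u32 {
    let mut x = x as u64;
    let mut y: u64 = 0;
    let mut s: i32 = 63;
    while s >= 0 {
        y += y;
        let b = 3 * y * (y + 1) + 1;
        if (x >> s) >= b {
            x -= b << s;
            y += 1;
        }
        s -= 3;
    }
    y as u32
}

fn main() {
    // 1625^3 = 4_291_015_625 is the largest cube that fits in a u32.
    for k in 1_u64..=1625 {
        let cube = k * k * k;
        assert_eq!(icbrt_u32(cube as u32) as u64, k);
        assert_eq!(icbrt_u32((cube - 1) as u32) as u64, k - 1);
    }
    assert_eq!(icbrt_u32(u32::MAX), 1625);
    println!("all cube boundaries match");
}
```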

8 participants