Auto merge of #128254 - Amanieu:orig-binary-search, r=tgross35 · model-checking/verify-rust-std@e13d132 (original) (raw)
`@@ -7,7 +7,7 @@
`
7
7
`#![stable(feature = "rust1", since = "1.0.0")]
`
8
8
``
9
9
`use crate::cmp::Ordering::{self, Equal, Greater, Less};
`
10
``
`-
use crate::intrinsics::{exact_div, unchecked_sub};
`
``
10
`+
use crate::intrinsics::{exact_div, select_unpredictable, unchecked_sub};
`
11
11
`use crate::mem::{self, SizedTypeProperties};
`
12
12
`use crate::num::NonZero;
`
13
13
`use crate::ops::{Bound, OneSidedRange, Range, RangeBounds};
`
`@@ -2770,41 +2770,54 @@ impl [T] {
`
2770
2770
`where
`
2771
2771
`F: FnMut(&'a T) -> Ordering,
`
2772
2772
`{
`
2773
``
`-
// INVARIANTS:
`
2774
``
`-
// - 0 <= left <= left + size = right <= self.len()
`
2775
``
`-
// - f returns Less for everything in self[..left]
`
2776
``
`-
// - f returns Greater for everything in self[right..]
`
2777
2773
`let mut size = self.len();
`
2778
``
`-
let mut left = 0;
`
2779
``
`-
let mut right = size;
`
2780
``
`-
while left < right {
`
2781
``
`-
let mid = left + size / 2;
`
2782
``
-
2783
``
`` -
// SAFETY: the while condition means `size` is strictly positive, so
``
2784
``
`` -
// `size/2 < size`. Thus `left + size/2 < left + size`, which
``
2785
``
`` -
// coupled with the `left + size <= self.len()` invariant means
``
2786
``
`` -
// we have `left + size/2 < self.len()`, and this is in-bounds.
``
``
2774
`+
if size == 0 {
`
``
2775
`+
return Err(0);
`
``
2776
`+
}
`
``
2777
`+
let mut base = 0usize;
`
``
2778
+
``
2779
`+
// This loop intentionally doesn't have an early exit if the comparison
`
``
2780
`+
// returns Equal. We want the number of loop iterations to depend only
`
``
2781
`+
// on the size of the input slice so that the CPU can reliably predict
`
``
2782
`+
// the loop count.
`
``
2783
`+
while size > 1 {
`
``
2784
`+
let half = size / 2;
`
``
2785
`+
let mid = base + half;
`
``
2786
+
``
2787
`+
// SAFETY: the call is made safe by the following invariants:
`
``
2788
`` +
// - `mid >= 0`: by definition
``
``
2789
`` +
// - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
``
2787
2790
`let cmp = f(unsafe { self.get_unchecked(mid) });
`
2788
2791
``
2789
``
`-
// This control flow produces conditional moves, which results in
`
2790
``
`-
// fewer branches and instructions than if/else or matching on
`
2791
``
`-
// cmp::Ordering.
`
2792
``
`-
// This is x86 asm for u8: https://rust.godbolt.org/z/698eYffTx.
`
2793
``
`-
left = if cmp == Less { mid + 1 } else { left };
`
2794
``
`-
right = if cmp == Greater { mid } else { right };
`
2795
``
`-
if cmp == Equal {
`
2796
``
`` -
// SAFETY: same as the `get_unchecked` above
``
2797
``
`-
unsafe { hint::assert_unchecked(mid < self.len()) };
`
2798
``
`-
return Ok(mid);
`
2799
``
`-
}
`
2800
``
-
2801
``
`-
size = right - left;
`
``
2792
`+
// Binary search interacts poorly with branch prediction, so force
`
``
2793
`+
// the compiler to use conditional moves if supported by the target
`
``
2794
`+
// architecture.
`
``
2795
`+
base = select_unpredictable(cmp == Greater, base, mid);
`
``
2796
+
``
2797
`` +
// This is imprecise in the case where `size` is odd and the
``
``
2798
`+
// comparison returns Greater: the mid element still gets included
`
``
2799
`` +
// by `size` even though it's known to be larger than the element
``
``
2800
`+
// being searched for.
`
``
2801
`+
//
`
``
2802
`+
// This is fine though: we gain more performance by keeping the
`
``
2803
`+
// loop iteration count invariant (and thus predictable) than we
`
``
2804
`+
// lose from considering one additional element.
`
``
2805
`+
size -= half;
`
2802
2806
`}
`
2803
2807
``
2804
``
`-
// SAFETY: directly true from the overall invariant.
`
2805
``
`` -
// Note that this is `<=`, unlike the assume in the `Ok` path.
``
2806
``
`-
unsafe { hint::assert_unchecked(left <= self.len()) };
`
2807
``
`-
Err(left)
`
``
2808
`+
// SAFETY: base is always in [0, size) because base <= mid.
`
``
2809
`+
let cmp = f(unsafe { self.get_unchecked(base) });
`
``
2810
`+
if cmp == Equal {
`
``
2811
`` +
// SAFETY: same as the `get_unchecked` above.
``
``
2812
`+
unsafe { hint::assert_unchecked(base < self.len()) };
`
``
2813
`+
Ok(base)
`
``
2814
`+
} else {
`
``
2815
`+
let result = base + (cmp == Less) as usize;
`
``
2816
`` +
// SAFETY: same as the `get_unchecked` above.
``
``
2817
`` +
// Note that this is `<=`, unlike the assume in the `Ok` path.
``
``
2818
`+
unsafe { hint::assert_unchecked(result <= self.len()) };
`
``
2819
`+
Err(result)
`
``
2820
`+
}
`
2808
2821
`}
`
2809
2822
``
2810
2823
`/// Binary searches this slice with a key extraction function.
`