Auto merge of #128254 - Amanieu:orig-binary-search, r=tgross35 · model-checking/verify-rust-std@e13d132 (original) (raw)

`@@ -7,7 +7,7 @@

`

7

7

`#![stable(feature = "rust1", since = "1.0.0")]

`

8

8

``

9

9

`use crate::cmp::Ordering::{self, Equal, Greater, Less};

`

10

``

`-

use crate::intrinsics::{exact_div, unchecked_sub};

`

``

10

`+

use crate::intrinsics::{exact_div, select_unpredictable, unchecked_sub};

`

11

11

`use crate::mem::{self, SizedTypeProperties};

`

12

12

`use crate::num::NonZero;

`

13

13

`use crate::ops::{Bound, OneSidedRange, Range, RangeBounds};

`

`@@ -2770,41 +2770,54 @@ impl [T] {

`

2770

2770

`where

`

2771

2771

`F: FnMut(&'a T) -> Ordering,

`

2772

2772

`{

`

2773

``

`-

// INVARIANTS:

`

2774

``

`-

// - 0 <= left <= left + size = right <= self.len()

`

2775

``

`-

// - f returns Less for everything in self[..left]

`

2776

``

`-

// - f returns Greater for everything in self[right..]

`

2777

2773

`let mut size = self.len();

`

2778

``

`-

let mut left = 0;

`

2779

``

`-

let mut right = size;

`

2780

``

`-

while left < right {

`

2781

``

`-

let mid = left + size / 2;

`

2782

``

-

2783

``

`` -

// SAFETY: the while condition means size is strictly positive, so

``

2784

``

`` -

// size/2 < size. Thus left + size/2 < left + size, which

``

2785

``

`` -

// coupled with the left + size <= self.len() invariant means

``

2786

``

`` -

// we have left + size/2 < self.len(), and this is in-bounds.

``

``

2774

`+

if size == 0 {

`

``

2775

`+

return Err(0);

`

``

2776

`+

}

`

``

2777

`+

let mut base = 0usize;

`

``

2778

+

``

2779

`+

// This loop intentionally doesn't have an early exit if the comparison

`

``

2780

`+

// returns Equal. We want the number of loop iterations to depend only

`

``

2781

`+

// on the size of the input slice so that the CPU can reliably predict

`

``

2782

`+

// the loop count.

`

``

2783

`+

while size > 1 {

`

``

2784

`+

let half = size / 2;

`

``

2785

`+

let mid = base + half;

`

``

2786

+

``

2787

`+

// SAFETY: the call is made safe by the following invariants:

`

``

2788

`` +

// - mid >= 0: by definition

``

``

2789

`` +

// - mid < size: mid = size / 2 + size / 4 + size / 8 ...

``

2787

2790

`let cmp = f(unsafe { self.get_unchecked(mid) });

`

2788

2791

``

2789

``

`-

// This control flow produces conditional moves, which results in

`

2790

``

`-

// fewer branches and instructions than if/else or matching on

`

2791

``

`-

// cmp::Ordering.

`

2792

``

`-

// This is x86 asm for u8: https://rust.godbolt.org/z/698eYffTx.

`

2793

``

`-

left = if cmp == Less { mid + 1 } else { left };

`

2794

``

`-

right = if cmp == Greater { mid } else { right };

`

2795

``

`-

if cmp == Equal {

`

2796

``

`` -

// SAFETY: same as the get_unchecked above

``

2797

``

`-

unsafe { hint::assert_unchecked(mid < self.len()) };

`

2798

``

`-

return Ok(mid);

`

2799

``

`-

}

`

2800

``

-

2801

``

`-

size = right - left;

`

``

2792

`+

// Binary search interacts poorly with branch prediction, so force

`

``

2793

`+

// the compiler to use conditional moves if supported by the target

`

``

2794

`+

// architecture.

`

``

2795

`+

base = select_unpredictable(cmp == Greater, base, mid);

`

``

2796

+

``

2797

`` +

// This is imprecise in the case where size is odd and the

``

``

2798

`+

// comparison returns Greater: the mid element still gets included

`

``

2799

`` +

// by size even though it's known to be larger than the element

``

``

2800

`+

// being searched for.

`

``

2801

`+

//

`

``

2802

`+

// This is fine though: we gain more performance by keeping the

`

``

2803

`+

// loop iteration count invariant (and thus predictable) than we

`

``

2804

`+

// lose from considering one additional element.

`

``

2805

`+

size -= half;

`

2802

2806

`}

`

2803

2807

``

2804

``

`-

// SAFETY: directly true from the overall invariant.

`

2805

``

`` -

// Note that this is <=, unlike the assume in the Ok path.

``

2806

``

`-

unsafe { hint::assert_unchecked(left <= self.len()) };

`

2807

``

`-

Err(left)

`

``

2808

`+

// SAFETY: base is always in [0, self.len()) because base <= mid.

`

``

2809

`+

let cmp = f(unsafe { self.get_unchecked(base) });

`

``

2810

`+

if cmp == Equal {

`

``

2811

`` +

// SAFETY: same as the get_unchecked above.

``

``

2812

`+

unsafe { hint::assert_unchecked(base < self.len()) };

`

``

2813

`+

Ok(base)

`

``

2814

`+

} else {

`

``

2815

`+

let result = base + (cmp == Less) as usize;

`

``

2816

`` +

// SAFETY: same as the get_unchecked above.

``

``

2817

`` +

// Note that this is <=, unlike the assume in the Ok path.

``

``

2818

`+

unsafe { hint::assert_unchecked(result <= self.len()) };

`

``

2819

`+

Err(result)

`

``

2820

`+

}

`

2808

2821

`}

`

2809

2822

``

2810

2823

`/// Binary searches this slice with a key extraction function.

`