Shrink heapsort further by combining sift_down loops · qinheping/verify-rust-std@cd3d6e8 (original) (raw)
1
1
`//! This module contains a branchless heapsort as fallback for unstable quicksort.
`
2
2
``
3
``
`-
use crate::{intrinsics, ptr};
`
``
3
`+
use crate::{cmp, intrinsics, ptr};
`
4
4
``
5
5
`` /// Sorts v
using heapsort, which guarantees O(n * log(n)) worst-case.
``
6
6
`///
`
7
7
`` /// Never inline this, it sits the main hot-loop in recurse
and is meant as unlikely algorithmic
``
8
8
`/// fallback.
`
9
``
`-
///
`
10
``
`` -
/// SAFETY: The caller has to guarantee that v.len()
>= 2.
``
11
9
`#[inline(never)]
`
12
``
`-
pub(crate) unsafe fn heapsort<T, F>(v: &mut [T], is_less: &mut F)
`
``
10
`+
pub(crate) fn heapsort<T, F>(v: &mut [T], is_less: &mut F)
`
13
11
`where
`
14
12
`F: FnMut(&T, &T) -> bool,
`
15
13
`{
`
16
``
`-
// SAFETY: See function safety.
`
17
``
`-
unsafe {
`
18
``
`-
intrinsics::assume(v.len() >= 2);
`
19
``
-
20
``
`-
// Build the heap in linear time.
`
21
``
`-
for i in (0..v.len() / 2).rev() {
`
22
``
`-
sift_down(v, i, is_less);
`
23
``
`-
}
`
``
14
`+
let len = v.len();
`
24
15
``
25
``
`-
// Pop maximal elements from the heap.
`
26
``
`-
for i in (1..v.len()).rev() {
`
``
16
`+
for i in (0..len + len / 2).rev() {
`
``
17
`+
let sift_idx = if i >= len {
`
``
18
`+
i - len
`
``
19
`+
} else {
`
27
20
` v.swap(0, i);
`
28
``
`-
sift_down(&mut v[..i], 0, is_less);
`
``
21
`+
0
`
``
22
`+
};
`
``
23
+
``
24
`` +
// SAFETY: The above calculation ensures that sift_idx
is either 0 or
``
``
25
`` +
// (len..(len + (len / 2))) - len
, which simplifies to 0..(len / 2)
.
``
``
26
`` +
// This guarantees the required sift_idx <= len
.
``
``
27
`+
unsafe {
`
``
28
`+
sift_down(&mut v[..cmp::min(i, len)], sift_idx, is_less);
`
29
29
`}
`
30
30
`}
`
31
31
`}
`
32
32
``
33
33
`` // This binary heap respects the invariant parent >= child
.
``
34
34
`//
`
35
``
`` -
// SAFETY: The caller has to guarantee that node < v.len()
.
``
36
``
`-
#[inline(never)]
`
``
35
`` +
// SAFETY: The caller has to guarantee that node <= v.len()
.
``
``
36
`+
#[inline(always)]
`
37
37
`unsafe fn sift_down<T, F>(v: &mut [T], mut node: usize, is_less: &mut F)
`
38
38
`where
`
39
39
`F: FnMut(&T, &T) -> bool,
`
40
40
`{
`
41
41
`// SAFETY: See function safety.
`
42
42
`unsafe {
`
43
``
`-
intrinsics::assume(node < v.len());
`
``
43
`+
intrinsics::assume(node <= v.len());
`
44
44
`}
`
45
45
``
46
46
`let len = v.len();
`