Improve slice.binary_search_by()'s best-case performance to O(1) by Folyd · Pull Request #74024 · rust-lang/rust (original) (raw)

I don't see much difference.

test std_binary_search_l1               ... bench:          58 ns/iter (+/- 1)
test std_binary_search_l1_with_dups     ... bench:          57 ns/iter (+/- 2)
test std_binary_search_l1_worst_case    ... bench:          10 ns/iter (+/- 0)
test std_binary_search_l2               ... bench:          75 ns/iter (+/- 2)
test std_binary_search_l2_with_dups     ... bench:          75 ns/iter (+/- 7)
test std_binary_search_l2_worst_case    ... bench:          15 ns/iter (+/- 0)
test std_binary_search_l3               ... bench:         237 ns/iter (+/- 6)
test std_binary_search_l3_with_dups     ... bench:         238 ns/iter (+/- 6)
test std_binary_search_l3_worst_case    ... bench:          23 ns/iter (+/- 0)
test stdnew_binary_search_l1            ... bench:          58 ns/iter (+/- 4)
test stdnew_binary_search_l1_with_dups  ... bench:          57 ns/iter (+/- 2)
test stdnew_binary_search_l1_worst_case ... bench:          10 ns/iter (+/- 0)
test stdnew_binary_search_l2            ... bench:          75 ns/iter (+/- 3)
test stdnew_binary_search_l2_with_dups  ... bench:          75 ns/iter (+/- 2)
test stdnew_binary_search_l2_worst_case ... bench:          15 ns/iter (+/- 1)
test stdnew_binary_search_l3            ... bench:         238 ns/iter (+/- 9)
test stdnew_binary_search_l3_with_dups  ... bench:         238 ns/iter (+/- 5)
test stdnew_binary_search_l3_worst_case ... bench:          23 ns/iter (+/- 0)

benches/bench.rs

#![feature(test)] extern crate test;

use test::black_box; use test::Bencher;

use binary_search::*;

/// Selects which input size a benchmark helper should build.
///
/// The helpers in this file map each variant to an element count
/// (1000, 10_000 and 1_000_000 respectively).
enum Cache {
    /// Smallest input.
    L1,
    /// Medium input.
    L2,
    /// Largest input.
    L3,
}

fn std_bench_binary_search(b: &mut Bencher, cache: Cache, mapper: F) where F: Fn(usize) -> usize, { let size = match cache { Cache::L1 => 1000, // 8kb Cache::L2 => 10_000, // 80kb Cache::L3 => 1_000_000, // 8Mb }; let v = (0..size).map(&mapper).collect::<Vec<_>>(); let mut r = 0usize; b.iter(move || { // LCG constants from https://en.wikipedia.org/wiki/Numerical_Recipes. r = r.wrapping_mul(1664525).wrapping_add(1013904223); // Lookup the whole range to get 50% hits and 50% misses. let i = mapper(r % size); black_box(std_binary_search(&v, &i).is_ok()); }) }

fn std_bench_binary_search_worst_case(b: &mut Bencher, cache: Cache) { let size = match cache { Cache::L1 => 1000, // 8kb Cache::L2 => 10_000, // 80kb Cache::L3 => 1_000_000, // 8Mb }; let mut v = vec![0; size]; let i = 1; v[size - 1] = i; b.iter(move || { black_box(std_binary_search(&v, &i).is_ok()); }) }

/// Runs the `std_binary_search` benchmark with `Cache::L1` over distinct even keys.
#[bench]
fn std_binary_search_l1(b: &mut Bencher) {
    let doubled = |n: usize| n * 2;
    std_bench_binary_search(b, Cache::L1, doubled);
}

/// Runs the `std_binary_search` benchmark with `Cache::L2` over distinct even keys.
#[bench]
fn std_binary_search_l2(b: &mut Bencher) {
    let doubled = |n: usize| n * 2;
    std_bench_binary_search(b, Cache::L2, doubled);
}

/// Runs the `std_binary_search` benchmark with `Cache::L3` over distinct even keys.
#[bench]
fn std_binary_search_l3(b: &mut Bencher) {
    let doubled = |n: usize| n * 2;
    std_bench_binary_search(b, Cache::L3, doubled);
}

/// Runs the `std_binary_search` benchmark with `Cache::L1` where each key is
/// rounded down to a multiple of 16, so values repeat in runs.
#[bench]
fn std_binary_search_l1_with_dups(b: &mut Bencher) {
    let round_down_16 = |n: usize| n / 16 * 16;
    std_bench_binary_search(b, Cache::L1, round_down_16);
}

/// Runs the `std_binary_search` benchmark with `Cache::L2` where each key is
/// rounded down to a multiple of 16, so values repeat in runs.
#[bench]
fn std_binary_search_l2_with_dups(b: &mut Bencher) {
    let round_down_16 = |n: usize| n / 16 * 16;
    std_bench_binary_search(b, Cache::L2, round_down_16);
}

/// Runs the `std_binary_search` benchmark with `Cache::L3` where each key is
/// rounded down to a multiple of 16, so values repeat in runs.
#[bench]
fn std_binary_search_l3_with_dups(b: &mut Bencher) {
    let round_down_16 = |n: usize| n / 16 * 16;
    std_bench_binary_search(b, Cache::L3, round_down_16);
}

/// Runs the last-element-match benchmark for `std_binary_search` with `Cache::L1`.
#[bench]
fn std_binary_search_l1_worst_case(b: &mut Bencher) {
    let tier = Cache::L1;
    std_bench_binary_search_worst_case(b, tier);
}

/// Runs the last-element-match benchmark for `std_binary_search` with `Cache::L2`.
#[bench]
fn std_binary_search_l2_worst_case(b: &mut Bencher) {
    let tier = Cache::L2;
    std_bench_binary_search_worst_case(b, tier);
}

/// Runs the last-element-match benchmark for `std_binary_search` with `Cache::L3`.
#[bench]
fn std_binary_search_l3_worst_case(b: &mut Bencher) {
    let tier = Cache::L3;
    std_bench_binary_search_worst_case(b, tier);
}

fn stdnew_bench_binary_search(b: &mut Bencher, cache: Cache, mapper: F) where F: Fn(usize) -> usize, { let size = match cache { Cache::L1 => 1000, // 8kb Cache::L2 => 10_000, // 80kb Cache::L3 => 1_000_000, // 8Mb }; let v = (0..size).map(&mapper).collect::<Vec<_>>(); let mut r = 0usize; b.iter(move || { // LCG constants from https://en.wikipedia.org/wiki/Numerical_Recipes. r = r.wrapping_mul(1664525).wrapping_add(1013904223); // Lookup the whole range to get 50% hits and 50% misses. let i = mapper(r % size); black_box(stdnew_binary_search(&v, &i).is_ok()); }) }

fn stdnew_bench_binary_search_worst_case(b: &mut Bencher, cache: Cache) { let size = match cache { Cache::L1 => 1000, // 8kb Cache::L2 => 10_000, // 80kb Cache::L3 => 1_000_000, // 8Mb }; let mut v = vec![0; size]; let i = 1; v[size - 1] = i; b.iter(move || { black_box(stdnew_binary_search(&v, &i).is_ok()); }) }

/// Runs the `stdnew_binary_search` benchmark with `Cache::L1` over distinct even keys.
#[bench]
fn stdnew_binary_search_l1(b: &mut Bencher) {
    let doubled = |n: usize| n * 2;
    stdnew_bench_binary_search(b, Cache::L1, doubled);
}

/// Runs the `stdnew_binary_search` benchmark with `Cache::L2` over distinct even keys.
#[bench]
fn stdnew_binary_search_l2(b: &mut Bencher) {
    let doubled = |n: usize| n * 2;
    stdnew_bench_binary_search(b, Cache::L2, doubled);
}

/// Runs the `stdnew_binary_search` benchmark with `Cache::L3` over distinct even keys.
#[bench]
fn stdnew_binary_search_l3(b: &mut Bencher) {
    let doubled = |n: usize| n * 2;
    stdnew_bench_binary_search(b, Cache::L3, doubled);
}

/// Runs the `stdnew_binary_search` benchmark with `Cache::L1` where each key is
/// rounded down to a multiple of 16, so values repeat in runs.
#[bench]
fn stdnew_binary_search_l1_with_dups(b: &mut Bencher) {
    let round_down_16 = |n: usize| n / 16 * 16;
    stdnew_bench_binary_search(b, Cache::L1, round_down_16);
}

/// Runs the `stdnew_binary_search` benchmark with `Cache::L2` where each key is
/// rounded down to a multiple of 16, so values repeat in runs.
#[bench]
fn stdnew_binary_search_l2_with_dups(b: &mut Bencher) {
    let round_down_16 = |n: usize| n / 16 * 16;
    stdnew_bench_binary_search(b, Cache::L2, round_down_16);
}

/// Runs the `stdnew_binary_search` benchmark with `Cache::L3` where each key is
/// rounded down to a multiple of 16, so values repeat in runs.
#[bench]
fn stdnew_binary_search_l3_with_dups(b: &mut Bencher) {
    let round_down_16 = |n: usize| n / 16 * 16;
    stdnew_bench_binary_search(b, Cache::L3, round_down_16);
}

/// Runs the last-element-match benchmark for `stdnew_binary_search` with `Cache::L1`.
#[bench]
fn stdnew_binary_search_l1_worst_case(b: &mut Bencher) {
    let tier = Cache::L1;
    stdnew_bench_binary_search_worst_case(b, tier);
}

/// Runs the last-element-match benchmark for `stdnew_binary_search` with `Cache::L2`.
#[bench]
fn stdnew_binary_search_l2_worst_case(b: &mut Bencher) {
    let tier = Cache::L2;
    stdnew_bench_binary_search_worst_case(b, tier);
}

/// Runs the last-element-match benchmark for `stdnew_binary_search` with `Cache::L3`.
#[bench]
fn stdnew_binary_search_l3_worst_case(b: &mut Bencher) {
    let tier = Cache::L3;
    stdnew_bench_binary_search_worst_case(b, tier);
}

src/lib.rs

use std::cmp::Ord; use std::cmp::Ordering::{self, Equal, Greater, Less};

pub fn std_binary_search(s: &[T], x: &T) -> Result<usize, usize> where T: Ord, { std_binary_search_by(s, |p| p.cmp(x)) }

/// Baseline binary search driven by a comparator.
///
/// `f` is called with an element and must report how that element orders
/// relative to the target (`Less` if the element is before the target,
/// `Greater` if after, `Equal` on a match). `s` must be sorted consistently
/// with `f`.
///
/// Returns `Ok(index)` of a matching element, or `Err(index)` giving the
/// insertion point that keeps `s` sorted. This variant never exits the loop
/// early: it always halves the window down to a single candidate and only
/// then tests it for equality.
pub fn std_binary_search_by<'a, T, F>(s: &'a [T], mut f: F) -> Result<usize, usize>
where
    F: FnMut(&'a T) -> Ordering,
{
    let mut size = s.len();
    if size == 0 {
        return Err(0);
    }

    // Invariant: `base + size <= s.len()` and `size >= 1`.
    let mut base = 0usize;
    while size > 1 {
        let half = size / 2;
        let mid = base + half;
        // SAFETY: `half < size`, hence `mid < base + size <= s.len()`.
        let cmp = f(unsafe { s.get_unchecked(mid) });
        // Keep the lower half only when the probe is past the target.
        if cmp != Greater {
            base = mid;
        }
        size -= half;
    }

    // SAFETY: the invariant with `size == 1` gives `base < s.len()`.
    match f(unsafe { s.get_unchecked(base) }) {
        Equal => Ok(base),
        Less => Err(base + 1),
        Greater => Err(base),
    }
}

/// Binary-searches the sorted slice `s` for `x` using the early-return algorithm.
///
/// Delegates to `stdnew_binary_search_by` with `Ord::cmp` against `x`, so that
/// callers of this wrapper actually exercise the new variant rather than the
/// baseline one.
///
/// Returns `Ok(index)` of a matching element, or `Err(index)` giving the
/// position where `x` could be inserted to keep `s` sorted. When several
/// elements equal `x`, any one of their indices may be returned.
pub fn stdnew_binary_search<T>(s: &[T], x: &T) -> Result<usize, usize>
where
    T: Ord,
{
    stdnew_binary_search_by(s, |p| p.cmp(x))
}

/// Binary search driven by a comparator that returns as soon as a probe matches.
///
/// `f` is called with an element and must report how that element orders
/// relative to the target (`Less` if the element is before the target,
/// `Greater` if after, `Equal` on a match). `s` must be sorted consistently
/// with `f`.
///
/// Returns `Ok(index)` of a matching element, or `Err(index)` giving the
/// insertion point that keeps `s` sorted.
pub fn stdnew_binary_search_by<'a, T, F>(s: &'a [T], mut f: F) -> Result<usize, usize>
where
    F: FnMut(&'a T) -> Ordering,
{
    let mut size = s.len();
    if size == 0 {
        return Err(0);
    }

    // Invariant: `base + size <= s.len()` and `size >= 1`.
    let mut base = 0usize;
    while size > 1 {
        let half = size / 2;
        let mid = base + half;
        // SAFETY: `half < size`, hence `mid < base + size <= s.len()`.
        match f(unsafe { s.get_unchecked(mid) }) {
            // A hit on the probe ends the search immediately.
            Equal => return Ok(mid),
            Less => base = mid,
            Greater => {}
        }
        size -= half;
    }

    // SAFETY: the invariant with `size == 1` gives `base < s.len()`.
    match f(unsafe { s.get_unchecked(base) }) {
        Equal => Ok(base),
        Less => Err(base + 1),
        Greater => Err(base),
    }
}