Auto merge of #123778 - jhorstmann:optimize-upper-lower-auto-vectoriz… · qinheping/verify-rust-std@7311aa8 (original) (raw)
`@@ -9,6 +9,7 @@
`
9
9
``
10
10
`use core::borrow::{Borrow, BorrowMut};
`
11
11
`use core::iter::FusedIterator;
`
``
12
`+
use core::mem::MaybeUninit;
`
12
13
`#[stable(feature = "encode_utf16", since = "1.8.0")]
`
13
14
`pub use core::str::EncodeUtf16;
`
14
15
`#[stable(feature = "split_ascii_whitespace", since = "1.34.0")]
`
`@@ -365,14 +366,9 @@ impl str {
`
365
366
` without modifying the original"]
`
366
367
`#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
`
367
368
`pub fn to_lowercase(&self) -> String {
`
368
``
`-
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase);
`
``
369
`+
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase);
`
369
370
``
370
``
`-
// Safety: we know this is a valid char boundary since
`
371
``
`-
// out.len() is only progressed if ascii bytes are found
`
372
``
`-
let rest = unsafe { self.get_unchecked(out.len()..) };
`
373
``
-
374
``
`-
// Safety: We have written only valid ASCII to our vec
`
375
``
`-
let mut s = unsafe { String::from_utf8_unchecked(out) };
`
``
371
`+
let prefix_len = s.len();
`
376
372
``
377
373
`for (i, c) in rest.char_indices() {
`
378
374
`if c == 'Σ' {
`
`@@ -381,8 +377,7 @@ impl str {
`
381
377
`` // in `SpecialCasing.txt`,
``
382
378
`// so hard-code it rather than have a generic "condition" mechanism.
`
383
379
`// See https://github.com/rust-lang/rust/issues/26035
`
384
``
`-
let out_len = self.len() - rest.len();
`
385
``
`-
let sigma_lowercase = map_uppercase_sigma(&self, i + out_len);
`
``
380
`+
let sigma_lowercase = map_uppercase_sigma(self, prefix_len + i);
`
386
381
` s.push(sigma_lowercase);
`
387
382
`} else {
`
388
383
`match conversions::to_lower(c) {
`
`@@ -458,14 +453,7 @@ impl str {
`
458
453
` without modifying the original"]
`
459
454
`#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
`
460
455
`pub fn to_uppercase(&self) -> String {
`
461
``
`-
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase);
`
462
``
-
463
``
`-
// Safety: we know this is a valid char boundary since
`
464
``
`-
// out.len() is only progressed if ascii bytes are found
`
465
``
`-
let rest = unsafe { self.get_unchecked(out.len()..) };
`
466
``
-
467
``
`-
// Safety: We have written only valid ASCII to our vec
`
468
``
`-
let mut s = unsafe { String::from_utf8_unchecked(out) };
`
``
456
`+
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase);
`
469
457
``
470
458
`for c in rest.chars() {
`
471
459
`match conversions::to_upper(c) {
`
`@@ -614,50 +602,87 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {
`
614
602
`unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
`
615
603
`}
`
616
604
``
617
``
`-
/// Converts the bytes while the bytes are still ascii.
`
``
605
`` +
/// Converts leading ascii bytes in `s` by calling the `convert` function.
``
``
606
`+
///
`
618
607
`` /// For better average performance, this happens in chunks of `2*size_of::<usize>()`.
``
619
``
`-
/// Returns a vec with the converted bytes.
`
``
608
`+
///
`
``
609
`+
/// Returns a tuple of the converted prefix and the remainder starting from
`
``
610
`+
/// the first non-ascii character.
`
``
611
`+
///
`
``
612
`+
/// This function is only public so that it can be verified in a codegen test,
`
``
613
`` +
/// see `issue-123712-str-to-lower-autovectorization.rs`.
``
``
614
`+
#[unstable(feature = "str_internals", issue = "none")]
`
``
615
`+
#[doc(hidden)]
`
620
616
`#[inline]
`
621
617
`#[cfg(not(test))]
`
622
618
`#[cfg(not(no_global_oom_handling))]
`
623
``
`-
fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> {
`
624
``
`-
let mut out = Vec::with_capacity(b.len());
`
``
619
`+
pub fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
`
``
620
`+
// Process the input in chunks of 16 bytes to enable auto-vectorization.
`
``
621
`` +
// Previously the chunk size depended on the size of `usize`,
``
``
622
`+
// but on 32-bit platforms with sse or neon, 16 bytes is also the better choice.
`
``
623
`+
// The only downside on other platforms would be a bit more loop-unrolling.
`
``
624
`+
const N: usize = 16;
`
``
625
+
``
626
`+
let mut slice = s.as_bytes();
`
``
627
`+
let mut out = Vec::with_capacity(slice.len());
`
``
628
`+
let mut out_slice = out.spare_capacity_mut();
`
``
629
+
``
630
`+
let mut ascii_prefix_len = 0_usize;
`
``
631
`+
let mut is_ascii = [false; N];
`
``
632
+
``
633
`+
while slice.len() >= N {
`
``
634
`+
// SAFETY: checked in loop condition
`
``
635
`+
let chunk = unsafe { slice.get_unchecked(..N) };
`
``
636
`+
// SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets
`
``
637
`+
let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };
`
``
638
+
``
639
`+
for j in 0..N {
`
``
640
`+
is_ascii[j] = chunk[j] <= 127;
`
``
641
`+
}
`
625
642
``
626
``
`-
const USIZE_SIZE: usize = mem::size_of::<usize>();
`
627
``
`-
const MAGIC_UNROLL: usize = 2;
`
628
``
`-
const N: usize = USIZE_SIZE * MAGIC_UNROLL;
`
629
``
`-
const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]);
`
``
643
`+
// Auto-vectorization for this check is a bit fragile, summing and comparing against the chunk
`
``
644
`+
// size gives the best result, specifically a pmovmsk instruction on x86.
`
``
645
`+
// See https://github.com/llvm/llvm-project/issues/96395 for why llvm currently does not
`
``
646
`+
// recognize other similar idioms.
`
``
647
`+
if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
`
``
648
`+
break;
`
``
649
`+
}
`
630
650
``
631
``
`-
let mut i = 0;
`
632
``
`-
unsafe {
`
633
``
`-
while i + N <= b.len() {
`
634
``
`` -
// Safety: we have checks the sizes `b` and `out` to know that our
``
635
``
`-
let in_chunk = b.get_unchecked(i..i + N);
`
636
``
`-
let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N);
`
637
``
-
638
``
`-
let mut bits = 0;
`
639
``
`-
for j in 0..MAGIC_UNROLL {
`
640
``
`-
// read the bytes 1 usize at a time (unaligned since we haven't checked the alignment)
`
641
``
`-
// safety: in_chunk is valid bytes in the range
`
642
``
`-
bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned();
`
643
``
`-
}
`
644
``
`-
// if our chunks aren't ascii, then return only the prior bytes as init
`
645
``
`-
if bits & NONASCII_MASK != 0 {
`
646
``
`-
break;
`
647
``
`-
}
`
``
651
`+
for j in 0..N {
`
``
652
`+
out_chunk[j] = MaybeUninit::new(convert(&chunk[j]));
`
``
653
`+
}
`
648
654
``
649
``
`-
// perform the case conversions on N bytes (gets heavily autovec'd)
`
650
``
`-
for j in 0..N {
`
651
``
`-
// safety: in_chunk and out_chunk is valid bytes in the range
`
652
``
`-
let out = out_chunk.get_unchecked_mut(j);
`
653
``
`-
out.write(convert(in_chunk.get_unchecked(j)));
`
654
``
`-
}
`
``
655
`+
ascii_prefix_len += N;
`
``
656
`+
slice = unsafe { slice.get_unchecked(N..) };
`
``
657
`+
out_slice = unsafe { out_slice.get_unchecked_mut(N..) };
`
``
658
`+
}
`
655
659
``
656
``
`-
// mark these bytes as initialised
`
657
``
`-
i += N;
`
``
660
`+
// handle the remainder as individual bytes
`
``
661
`+
while slice.len() > 0 {
`
``
662
`+
let byte = slice[0];
`
``
663
`+
if byte > 127 {
`
``
664
`+
break;
`
``
665
`+
}
`
``
666
`+
// SAFETY: out_slice has at least same length as input slice
`
``
667
`+
unsafe {
`
``
668
`+
*out_slice.get_unchecked_mut(0) = MaybeUninit::new(convert(&byte));
`
658
669
`}
`
659
``
`-
out.set_len(i);
`
``
670
`+
ascii_prefix_len += 1;
`
``
671
`+
slice = unsafe { slice.get_unchecked(1..) };
`
``
672
`+
out_slice = unsafe { out_slice.get_unchecked_mut(1..) };
`
660
673
`}
`
661
674
``
662
``
`-
out
`
``
675
`+
unsafe {
`
``
676
`+
// SAFETY: ascii_prefix_len bytes have been initialized above
`
``
677
`+
out.set_len(ascii_prefix_len);
`
``
678
+
``
679
`+
// SAFETY: We have written only valid ascii to the output vec
`
``
680
`+
let ascii_string = String::from_utf8_unchecked(out);
`
``
681
+
``
682
`+
// SAFETY: we know this is a valid char boundary
`
``
683
`+
// since we only skipped over leading ascii bytes
`
``
684
`+
let rest = core::str::from_utf8_unchecked(slice);
`
``
685
+
``
686
`+
(ascii_string, rest)
`
``
687
`+
}
`
663
688
`}
`