Auto merge of #123778 - jhorstmann:optimize-upper-lower-auto-vectoriz… · qinheping/verify-rust-std@7311aa8 (original) (raw)

`@@ -9,6 +9,7 @@

`

9

9

``

10

10

`use core::borrow::{Borrow, BorrowMut};

`

11

11

`use core::iter::FusedIterator;

`

``

12

`+

use core::mem::MaybeUninit;

`

12

13

`#[stable(feature = "encode_utf16", since = "1.8.0")]

`

13

14

`pub use core::str::EncodeUtf16;

`

14

15

`#[stable(feature = "split_ascii_whitespace", since = "1.34.0")]

`

`@@ -365,14 +366,9 @@ impl str {

`

365

366

` without modifying the original"]

`

366

367

`#[stable(feature = "unicode_case_mapping", since = "1.2.0")]

`

367

368

`pub fn to_lowercase(&self) -> String {

`

368

``

`-

let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase);

`

``

369

`+

let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase);

`

369

370

``

370

``

`-

// Safety: we know this is a valid char boundary since

`

371

``

`-

// out.len() is only progressed if ascii bytes are found

`

372

``

`-

let rest = unsafe { self.get_unchecked(out.len()..) };

`

373

``

-

374

``

`-

// Safety: We have written only valid ASCII to our vec

`

375

``

`-

let mut s = unsafe { String::from_utf8_unchecked(out) };

`

``

371

`+

let prefix_len = s.len();

`

376

372

``

377

373

`for (i, c) in rest.char_indices() {

`

378

374

`if c == 'Σ' {

`

`@@ -381,8 +377,7 @@ impl str {

`

381

377

`` // in `SpecialCasing.txt`,

``

382

378

`// so hard-code it rather than have a generic "condition" mechanism.

`

383

379

`// See https://github.com/rust-lang/rust/issues/26035

`

384

``

`-

let out_len = self.len() - rest.len();

`

385

``

`-

let sigma_lowercase = map_uppercase_sigma(&self, i + out_len);

`

``

380

`+

let sigma_lowercase = map_uppercase_sigma(self, prefix_len + i);

`

386

381

` s.push(sigma_lowercase);

`

387

382

`} else {

`

388

383

`match conversions::to_lower(c) {

`

`@@ -458,14 +453,7 @@ impl str {

`

458

453

` without modifying the original"]

`

459

454

`#[stable(feature = "unicode_case_mapping", since = "1.2.0")]

`

460

455

`pub fn to_uppercase(&self) -> String {

`

461

``

`-

let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase);

`

462

``

-

463

``

`-

// Safety: we know this is a valid char boundary since

`

464

``

`-

// out.len() is only progressed if ascii bytes are found

`

465

``

`-

let rest = unsafe { self.get_unchecked(out.len()..) };

`

466

``

-

467

``

`-

// Safety: We have written only valid ASCII to our vec

`

468

``

`-

let mut s = unsafe { String::from_utf8_unchecked(out) };

`

``

456

`+

let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase);

`

469

457

``

470

458

`for c in rest.chars() {

`

471

459

`match conversions::to_upper(c) {

`

`@@ -614,50 +602,87 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {

`

614

602

`unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }

`

615

603

`}

`

616

604

``

617

``

`-

/// Converts the bytes while the bytes are still ascii.

`

``

605

`` +

/// Converts leading ascii bytes in `s` by calling the `convert` function.

``

``

606

`+

///

`

618

607

`` /// For better average performance, this happens in chunks of `2*size_of::<usize>()`.

``

619

``

`-

/// Returns a vec with the converted bytes.

`

``

608

`+

///

`

``

609

`+

/// Returns a tuple of the converted prefix and the remainder starting from

`

``

610

`+

/// the first non-ascii character.

`

``

611

`+

///

`

``

612

`+

/// This function is only public so that it can be verified in a codegen test,

`

``

613

`` +

/// see `issue-123712-str-to-lower-autovectorization.rs`.

``

``

614

`+

#[unstable(feature = "str_internals", issue = "none")]

`

``

615

`+

#[doc(hidden)]

`

620

616

`#[inline]

`

621

617

`#[cfg(not(test))]

`

622

618

`#[cfg(not(no_global_oom_handling))]

`

623

``

`-

fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> {

`

624

``

`-

let mut out = Vec::with_capacity(b.len());

`

``

619

`+

pub fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {

`

``

620

`+

// Process the input in chunks of 16 bytes to enable auto-vectorization.

`

``

621

`` +

// Previously the chunk size depended on the size of `usize`,

``

``

622

`+

// but on 32-bit platforms with sse or neon, 16 bytes is also the better choice.

`

``

623

`+

// The only downside on other platforms would be a bit more loop-unrolling.

`

``

624

`+

const N: usize = 16;

`

``

625

+

``

626

`+

let mut slice = s.as_bytes();

`

``

627

`+

let mut out = Vec::with_capacity(slice.len());

`

``

628

`+

let mut out_slice = out.spare_capacity_mut();

`

``

629

+

``

630

`+

let mut ascii_prefix_len = 0_usize;

`

``

631

`+

let mut is_ascii = [false; N];

`

``

632

+

``

633

`+

while slice.len() >= N {

`

``

634

`+

// SAFETY: checked in loop condition

`

``

635

`+

let chunk = unsafe { slice.get_unchecked(..N) };

`

``

636

`+

// SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets

`

``

637

`+

let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };

`

``

638

+

``

639

`+

for j in 0..N {

`

``

640

`+

is_ascii[j] = chunk[j] <= 127;

`

``

641

`+

}

`

625

642

``

626

``

`-

const USIZE_SIZE: usize = mem::size_of::<usize>();

`

627

``

`-

const MAGIC_UNROLL: usize = 2;

`

628

``

`-

const N: usize = USIZE_SIZE * MAGIC_UNROLL;

`

629

``

`-

const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]);

`

``

643

`+

// Auto-vectorization for this check is a bit fragile; summing and comparing against the chunk

`

``

644

`+

// size gives the best result, specifically a pmovmsk instruction on x86.

`

``

645

`+

// See https://github.com/llvm/llvm-project/issues/96395 for why llvm currently does not

`

``

646

`+

// recognize other similar idioms.

`

``

647

`+

if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {

`

``

648

`+

break;

`

``

649

`+

}

`

630

650

``

631

``

`-

let mut i = 0;

`

632

``

`-

unsafe {

`

633

``

`-

while i + N <= b.len() {

`

634

``

`` -

// Safety: we have checks the sizes b and out to know that our

``

635

``

`-

let in_chunk = b.get_unchecked(i..i + N);

`

636

``

`-

let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N);

`

637

``

-

638

``

`-

let mut bits = 0;

`

639

``

`-

for j in 0..MAGIC_UNROLL {

`

640

``

`-

// read the bytes 1 usize at a time (unaligned since we haven't checked the alignment)

`

641

``

`-

// safety: in_chunk is valid bytes in the range

`

642

``

`-

bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned();

`

643

``

`-

}

`

644

``

`-

// if our chunks aren't ascii, then return only the prior bytes as init

`

645

``

`-

if bits & NONASCII_MASK != 0 {

`

646

``

`-

break;

`

647

``

`-

}

`

``

651

`+

for j in 0..N {

`

``

652

`+

out_chunk[j] = MaybeUninit::new(convert(&chunk[j]));

`

``

653

`+

}

`

648

654

``

649

``

`-

// perform the case conversions on N bytes (gets heavily autovec'd)

`

650

``

`-

for j in 0..N {

`

651

``

`-

// safety: in_chunk and out_chunk is valid bytes in the range

`

652

``

`-

let out = out_chunk.get_unchecked_mut(j);

`

653

``

`-

out.write(convert(in_chunk.get_unchecked(j)));

`

654

``

`-

}

`

``

655

`+

ascii_prefix_len += N;

`

``

656

`+

slice = unsafe { slice.get_unchecked(N..) };

`

``

657

`+

out_slice = unsafe { out_slice.get_unchecked_mut(N..) };

`

``

658

`+

}

`

655

659

``

656

``

`-

// mark these bytes as initialised

`

657

``

`-

i += N;

`

``

660

`+

// handle the remainder as individual bytes

`

``

661

`+

while slice.len() > 0 {

`

``

662

`+

let byte = slice[0];

`

``

663

`+

if byte > 127 {

`

``

664

`+

break;

`

``

665

`+

}

`

``

666

`+

// SAFETY: out_slice has at least same length as input slice

`

``

667

`+

unsafe {

`

``

668

`+

*out_slice.get_unchecked_mut(0) = MaybeUninit::new(convert(&byte));

`

658

669

`}

`

659

``

`-

out.set_len(i);

`

``

670

`+

ascii_prefix_len += 1;

`

``

671

`+

slice = unsafe { slice.get_unchecked(1..) };

`

``

672

`+

out_slice = unsafe { out_slice.get_unchecked_mut(1..) };

`

660

673

`}

`

661

674

``

662

``

`-

out

`

``

675

`+

unsafe {

`

``

676

`+

// SAFETY: ascii_prefix_len bytes have been initialized above

`

``

677

`+

out.set_len(ascii_prefix_len);

`

``

678

+

``

679

`+

// SAFETY: We have written only valid ascii to the output vec

`

``

680

`+

let ascii_string = String::from_utf8_unchecked(out);

`

``

681

+

``

682

`+

// SAFETY: we know this is a valid char boundary

`

``

683

`+

// since we only skipped over leading ascii bytes

`

``

684

`+

let rest = core::str::from_utf8_unchecked(slice);

`

``

685

+

``

686

`+

(ascii_string, rest)

`

``

687

`+

}

`

663

688

`}

`