Add auto-bitcasts between x86amx
and i32x256
for AMX intrinsics by sayantn · Pull Request #140763 · rust-lang/rust (original) (raw)
I tried adding support for AMX Tile types to Rust. They are very simple - the LLVM intrinsics operate on x86amx
types, and all that is needed for us to call those intrinsics is inserting bitcasts to/from x86amx
and i32x256
before and after the function call (as in this file in LLVM)
I tested the codegen for this fragmenttest.rs
#![feature( link_llvm_intrinsics, abi_unadjusted, x86_amx_intrinsics, repr_simd, simd_ffi )] #![allow(internal_features)] #![no_std]
#[repr(simd)] pub struct Tile([u32; 256]);
#[allow(improper_ctypes)] unsafe extern "unadjusted" { #[link_name = "llvm.x86.tdpbuud.internal"] fn tdpbuud(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile; }
#[unsafe(no_mangle)] #[target_feature(enable = "amx-int8")] pub fn foo(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile { unsafe { tdpbuud(m, n, k, a, b, c) } }
The LLVM IR generated is (output of rustc +stage1 --emit=llvm-ir --crate-type=rlib -O test.rs && cat test.ll
)
; ModuleID = 'test.acdeec3141bb4e39-cgu.0' source_filename = "test.acdeec3141bb4e39-cgu.0" target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" target triple = "x86_64-unknown-linux-gnu"
; Function Attrs: nonlazybind uwtable define void @foo(ptr sret([1024 x i8]) align 1024 %_0, i16 %m, i16 %n, i16 %k, ptr align 1024 %a, ptr align 1024 %b, ptr align 1024 %c) unnamed_addr #0 { start: %0 = load <256 x i32>, ptr %a, align 1024 %1 = load <256 x i32>, ptr %b, align 1024 %2 = load <256 x i32>, ptr %c, align 1024 %3 = bitcast <256 x i32> %0 to x86_amx %4 = bitcast <256 x i32> %1 to x86_amx %5 = bitcast <256 x i32> %2 to x86_amx %6 = call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx %3, x86_amx %4, x86_amx %5) #1 %7 = bitcast x86_amx %6 to <256 x i32> store <256 x i32> %7, ptr %_0, align 1024 ret void }
; Function Attrs: nounwind declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) unnamed_addr #1
attributes #0 = { nonlazybind uwtable "probe-stack"="inline-asm" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile" } attributes #1 = { nounwind }
!llvm.module.flags = !{!0, !1} !llvm.ident = !{!2}
!0 = !{i32 8, !"PIC Level", i32 2} !1 = !{i32 2, !"RtLibUseGOT", i32 1} !2 = !{!"rustc version 1.88.0-dev"}
and the ASM generated is (output of rustc +stage1 --emit=asm --crate-type=rlib -O test.rs && cat test.s
)
.file "test.acdeec3141bb4e39-cgu.0"
.section .text.foo,"ax",@progbits
.globl foo
.p2align 4
.type foo,@function
foo: .cfi_startproc movq %rdi, %rax xorps %xmm0, %xmm0 movups %xmm0, -64(%rsp) movups %xmm0, -48(%rsp) movups %xmm0, -32(%rsp) movups %xmm0, -16(%rsp) movb $1, -64(%rsp) movw %dx, -44(%rsp) movb %sil, -15(%rsp) movw %dx, -48(%rsp) movb %sil, -16(%rsp) movzwl %cx, %ecx movw %cx, -46(%rsp) movq 8(%rsp), %rdi movl %ecx, %r10d movb %r10b, -14(%rsp) shrl $2, %r10d movb %r10b, -14(%rsp) movl $64, %r11d ldtilecfg -64(%rsp) tileloadd (%r8,%r11), %tmm0 tileloadd (%r9,%r11), %tmm1 tileloadd (%rdi,%r11), %tmm2 tdpbuud %tmm2, %tmm1, %tmm0 tilestored %tmm0, (%rax,%r11) tilerelease retq .Lfunc_end0: .size foo, .Lfunc_end0-foo .cfi_endproc
.ident "rustc version 1.88.0-dev"
.section ".note.GNU-stack","",@progbits
(note: the tests were done on x86_64-unknown-linux-gnu
)
This is pretty similar to the CLang codegen (https://godbolt.org/z/G19rjo3Ke).
Reviews are welcome, as I am not too confident in the code (I am still not sure if the checks for AMX are strict enough, I will try strengthen them).
Unresolved Questions
Areturns outbitcast
's good enough? CLang usesllvm.x86.cast.vector.to.tile.v256i32
andllvm.x86.cast.tile.to.vector.v256i32
, is there any functional difference withbitcast
s?bitcast
can cause miscompilation (https://reviews.llvm.org/D99152), so we have to use the amx-specific casts- Should we allow only
i32x256
, or all vector types of size 8192? The LLVM file I referenced only does this fori32x256
, but there is really not reason to be restrictive.
@rustbot label O-x86_64 T-compiler
r? codegen