core/stdarch/crates/core_arch/src/x86_64/
amx.rs

1#[cfg(test)]
2use stdarch_test::assert_instr;
3
4/// Load tile configuration from a 64-byte memory location specified by mem_addr.
5/// The tile configuration format is specified below, and includes the tile type pallette,
6/// the number of bytes per row, and the number of rows. If the specified pallette_id is zero,
7/// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed.
8/// Any invalid configurations will result in #GP fault.
9///
10/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875)
11#[inline]
12#[target_feature(enable = "amx-tile")]
13#[cfg_attr(test, assert_instr(ldtilecfg))]
14#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
15pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
16    ldtilecfg(mem_addr);
17}
18
19/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr.
20/// The tile configuration format is specified below, and includes the tile type pallette,
21/// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory.
22///
23/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879)
24#[inline]
25#[target_feature(enable = "amx-tile")]
26#[cfg_attr(test, assert_instr(sttilecfg))]
27#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
28pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
29    sttilecfg(mem_addr);
30}
31
32/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig.
33///
34/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877)
35#[inline]
36#[rustc_legacy_const_generics(0)]
37#[target_feature(enable = "amx-tile")]
38#[cfg_attr(test, assert_instr(tileloadd, DST = 0))]
39#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
40pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) {
41    static_assert_uimm_bits!(DST, 3);
42    tileloadd64(DST as i8, base, stride);
43}
44
45/// Release the tile configuration to return to the init state, which releases all storage it currently holds.
46///
47/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878)
48#[inline]
49#[target_feature(enable = "amx-tile")]
50#[cfg_attr(test, assert_instr(tilerelease))]
51#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
52pub unsafe fn _tile_release() {
53    tilerelease();
54}
55
56/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig.
57///
58/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881)
59#[inline]
60#[rustc_legacy_const_generics(0)]
61#[target_feature(enable = "amx-tile")]
62#[cfg_attr(test, assert_instr(tilestored, DST = 0))]
63#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
64pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) {
65    static_assert_uimm_bits!(DST, 3);
66    tilestored64(DST as i8, base, stride);
67}
68
69/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration
70/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will
71/// likely not be reused in the near future and the data caching can be optimized accordingly.
72///
73/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883)
74#[inline]
75#[rustc_legacy_const_generics(0)]
76#[target_feature(enable = "amx-tile")]
77#[cfg_attr(test, assert_instr(tileloaddt1, DST = 0))]
78#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
79pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) {
80    static_assert_uimm_bits!(DST, 3);
81    tileloaddt164(DST as i8, base, stride);
82}
83
84/// Zero the tile specified by tdest.
85///
86/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885)
87#[inline]
88#[rustc_legacy_const_generics(0)]
89#[target_feature(enable = "amx-tile")]
90#[cfg_attr(test, assert_instr(tilezero, DST = 0))]
91#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
92pub unsafe fn _tile_zero<const DST: i32>() {
93    static_assert_uimm_bits!(DST, 3);
94    tilezero(DST as i8);
95}
96
97/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b,
98/// accumulating the intermediate single-precision (32-bit) floating-point elements
99/// with elements in dst, and store the 32-bit result back to tile dst.
100///
101/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864)
102#[inline]
103#[rustc_legacy_const_generics(0, 1, 2)]
104#[target_feature(enable = "amx-bf16")]
105#[cfg_attr(test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))]
106#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
107pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() {
108    static_assert_uimm_bits!(DST, 3);
109    static_assert_uimm_bits!(A, 3);
110    static_assert_uimm_bits!(B, 3);
111    tdpbf16ps(DST as i8, A as i8, B as i8);
112}
113
114/// Compute dot-product of bytes in tiles with a source/destination accumulator.
115/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
116/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
117/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
118///
119/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866)
120#[inline]
121#[rustc_legacy_const_generics(0, 1, 2)]
122#[target_feature(enable = "amx-int8")]
123#[cfg_attr(test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))]
124#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
125pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() {
126    static_assert_uimm_bits!(DST, 3);
127    static_assert_uimm_bits!(A, 3);
128    static_assert_uimm_bits!(B, 3);
129    tdpbssd(DST as i8, A as i8, B as i8);
130}
131
132/// Compute dot-product of bytes in tiles with a source/destination accumulator.
133/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
134/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
135/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
136///
137/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868)
138#[inline]
139#[rustc_legacy_const_generics(0, 1, 2)]
140#[target_feature(enable = "amx-int8")]
141#[cfg_attr(test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))]
142#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
143pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() {
144    static_assert_uimm_bits!(DST, 3);
145    static_assert_uimm_bits!(A, 3);
146    static_assert_uimm_bits!(B, 3);
147    tdpbsud(DST as i8, A as i8, B as i8);
148}
149
150/// Compute dot-product of bytes in tiles with a source/destination accumulator.
151/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
152/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
153/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
154///
155/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870)
156#[inline]
157#[rustc_legacy_const_generics(0, 1, 2)]
158#[target_feature(enable = "amx-int8")]
159#[cfg_attr(test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))]
160#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
161pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() {
162    static_assert_uimm_bits!(DST, 3);
163    static_assert_uimm_bits!(A, 3);
164    static_assert_uimm_bits!(B, 3);
165    tdpbusd(DST as i8, A as i8, B as i8);
166}
167
168/// Compute dot-product of bytes in tiles with a source/destination accumulator.
169/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
170/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
171/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
172///
173/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872)
174#[inline]
175#[rustc_legacy_const_generics(0, 1, 2)]
176#[target_feature(enable = "amx-int8")]
177#[cfg_attr(test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))]
178#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
179pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() {
180    static_assert_uimm_bits!(DST, 3);
181    static_assert_uimm_bits!(A, 3);
182    static_assert_uimm_bits!(B, 3);
183    tdpbuud(DST as i8, A as i8, B as i8);
184}
185
186/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b,
187/// accumulating the intermediate single-precision (32-bit) floating-point elements
188///  with elements in dst, and store the 32-bit result back to tile dst.
189///
190/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874)
191#[inline]
192#[rustc_legacy_const_generics(0, 1, 2)]
193#[target_feature(enable = "amx-fp16")]
194#[cfg_attr(test, assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2))]
195#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
196pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() {
197    static_assert_uimm_bits!(DST, 3);
198    static_assert_uimm_bits!(A, 3);
199    static_assert_uimm_bits!(B, 3);
200    tdpfp16ps(DST as i8, A as i8, B as i8);
201}
202
203/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
204/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
205/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b),
206/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
207/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of
208/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added,
209/// and then accumulated into the corresponding row and column of dst.
210///
211/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860)
212#[inline]
213#[rustc_legacy_const_generics(0, 1, 2)]
214#[target_feature(enable = "amx-complex")]
215#[cfg_attr(test, assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2))]
216#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
217pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() {
218    static_assert_uimm_bits!(DST, 3);
219    static_assert_uimm_bits!(A, 3);
220    static_assert_uimm_bits!(B, 3);
221    tcmmimfp16ps(DST as i8, A as i8, B as i8);
222}
223
224/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
225/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
226/// Calculates the real part of the result. For each possible combination of (row of a, column of b),
227/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
228/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of
229/// the a element is multiplied with the imaginary part of the corresponding b elements.
230/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst.
231///
232/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862)
233#[inline]
234#[rustc_legacy_const_generics(0, 1, 2)]
235#[target_feature(enable = "amx-complex")]
236#[cfg_attr(test, assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2))]
237#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
238pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
239    static_assert_uimm_bits!(DST, 3);
240    static_assert_uimm_bits!(A, 3);
241    static_assert_uimm_bits!(B, 3);
242    tcmmrlfp16ps(DST as i8, A as i8, B as i8);
243}
244
245#[allow(improper_ctypes)]
246unsafe extern "C" {
247    #[link_name = "llvm.x86.ldtilecfg"]
248    fn ldtilecfg(mem_addr: *const u8);
249    #[link_name = "llvm.x86.sttilecfg"]
250    fn sttilecfg(mem_addr: *mut u8);
251    #[link_name = "llvm.x86.tileloadd64"]
252    fn tileloadd64(dst: i8, base: *const u8, stride: usize);
253    #[link_name = "llvm.x86.tileloaddt164"]
254    fn tileloaddt164(dst: i8, base: *const u8, stride: usize);
255    #[link_name = "llvm.x86.tilerelease"]
256    fn tilerelease();
257    #[link_name = "llvm.x86.tilestored64"]
258    fn tilestored64(dst: i8, base: *mut u8, stride: usize);
259    #[link_name = "llvm.x86.tilezero"]
260    fn tilezero(dst: i8);
261    #[link_name = "llvm.x86.tdpbf16ps"]
262    fn tdpbf16ps(dst: i8, a: i8, b: i8);
263    #[link_name = "llvm.x86.tdpbuud"]
264    fn tdpbuud(dst: i8, a: i8, b: i8);
265    #[link_name = "llvm.x86.tdpbusd"]
266    fn tdpbusd(dst: i8, a: i8, b: i8);
267    #[link_name = "llvm.x86.tdpbsud"]
268    fn tdpbsud(dst: i8, a: i8, b: i8);
269    #[link_name = "llvm.x86.tdpbssd"]
270    fn tdpbssd(dst: i8, a: i8, b: i8);
271    #[link_name = "llvm.x86.tdpfp16ps"]
272    fn tdpfp16ps(dst: i8, a: i8, b: i8);
273    #[link_name = "llvm.x86.tcmmimfp16ps"]
274    fn tcmmimfp16ps(dst: i8, a: i8, b: i8);
275    #[link_name = "llvm.x86.tcmmrlfp16ps"]
276    fn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
277}
278
#[cfg(test)]
mod tests {
    use crate::core_arch::x86_64::*;
    use stdarch_test::simd_test;
    #[cfg(target_os = "linux")]
    use syscalls::{Sysno, syscall};

    /// In-memory image of the 64-byte tile configuration read by `_tile_loadconfig` and
    /// written by `_tile_storeconfig`.
    #[allow(non_camel_case_types)] // lowercase, C-style type name is intentional
    #[repr(packed)]
    #[derive(Copy, Clone, Default, Debug, PartialEq)]
    struct __tilecfg {
        /// 0 `or` 1
        palette: u8,
        start_row: u8,
        /// reserved, must be zero
        reserved_a0: [u8; 14],
        /// number of bytes of one row in each tile
        colsb: [u16; 8],
        /// reserved, must be zero
        reserved_b0: [u16; 8],
        /// number of rows in each tile
        rows: [u8; 8],
        /// reserved, must be zero
        reserved_c0: [u8; 8],
    }

    impl __tilecfg {
        fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
            Self {
                palette,
                start_row,
                reserved_a0: [0u8; 14],
                colsb,
                reserved_b0: [0u16; 8],
                rows,
                reserved_c0: [0u8; 8],
            }
        }

        const fn as_ptr(&self) -> *const u8 {
            self as *const Self as *const u8
        }

        fn as_mut_ptr(&mut self) -> *mut u8 {
            self as *mut Self as *mut u8
        }
    }

    /// Builds a palette-1 configuration in which tiles `0..count` are each 16 rows of
    /// 64 bytes; all other tiles are left unconfigured (zero).
    fn config_16x64(count: usize) -> __tilecfg {
        let mut config = __tilecfg::default();
        config.palette = 1;
        for tile in 0..count {
            config.colsb[tile] = 64;
            config.rows[tile] = 16;
        }
        config
    }

    #[cfg(not(target_os = "linux"))]
    #[target_feature(enable = "amx-tile")]
    fn _init_amx() {}

    /// Requests permission to use AMX tile data for this process, panicking if AMX is
    /// unavailable or the request fails.
    #[cfg(target_os = "linux")]
    #[target_feature(enable = "amx-tile")]
    #[inline]
    unsafe fn _init_amx() {
        const ARCH_GET_XCOMP_PERM: usize = 0x1022;
        const ARCH_REQ_XCOMP_PERM: usize = 0x1023;
        const XFEATURE_XTILEDATA: usize = 18;

        let mut xfeatures: usize = 0;
        let ret = syscall!(Sysno::arch_prctl, ARCH_GET_XCOMP_PERM, &mut xfeatures as *mut usize)
            .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
        if ret != 0 {
            panic!("Failed to get XFEATURES");
        }
        // Inspect bits 17 (XTILECFG) and 18 (XTILEDATA) of the permitted-feature mask.
        match 0b11 & (xfeatures >> 17) {
            0 => panic!("AMX is not available"),
            1 => {
                let ret = syscall!(Sysno::arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)
                    .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
                if ret != 0 {
                    panic!("Failed to enable AMX");
                }
            }
            3 => {}
            _ => unreachable!(),
        }
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_loadconfig() {
        // Palette 0 selects the init state.
        let config = __tilecfg::default();
        _tile_loadconfig(config.as_ptr());
        _tile_release();
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_storeconfig() {
        // A non-zero palette configures tile data, so permission must be requested here
        // rather than relying on another test having done it first.
        _init_amx();
        let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
        _tile_loadconfig(config.as_ptr());
        let mut stored = __tilecfg::default();
        _tile_storeconfig(stored.as_mut_ptr());
        _tile_release();
        assert_eq!(config, stored);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_zero() {
        _init_amx();
        let config = config_16x64(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mut out = [[1_i8; 64]; 16];
        _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
        _tile_release();
        assert_eq!(out, [[0; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_stored() {
        _init_amx();
        let config = config_16x64(1);
        _tile_loadconfig(config.as_ptr());
        // A distinct-ish value per (row, column) so misplaced rows or bytes are detected.
        let mat: [[i8; 64]; 16] =
            core::array::from_fn(|r| core::array::from_fn(|c| ((r * 64 + c) % 251) as i8));
        _tile_loadd::<0>(&mat as *const [i8; 64] as *const u8, 64);
        // Store with a 128-byte stride: each tile row must land in the first half of its
        // 128-byte slot, and the second half must stay untouched.
        let mut out = [[-1_i8; 128]; 16];
        _tile_stored::<0>(&mut out as *mut [i8; 128] as *mut u8, 128);
        _tile_release();
        for (row, expected) in out.iter().zip(mat.iter()) {
            assert_eq!(&row[..64], &expected[..]);
            assert_eq!(&row[64..], &[-1_i8; 64][..]);
        }
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_loadd() {
        _init_amx();
        let config = config_16x64(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mat = [1_i8; 1024];
        _tile_loadd::<0>(&mat as *const i8 as *const u8, 64);
        let mut out = [[0_i8; 64]; 16];
        _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
        _tile_release();
        assert_eq!(out, [[1; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_stream_loadd() {
        _init_amx();
        let config = config_16x64(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mat = [1_i8; 1024];
        _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64);
        let mut out = [[0_i8; 64]; 16];
        _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
        _tile_release();
        assert_eq!(out, [[1; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_release() {
        _tile_release();
    }

    #[simd_test(enable = "amx-bf16")]
    unsafe fn test_tile_dpbf16ps() {
        _init_amx();
        // bf16 is the upper half of an f32; 1.0 and 2.0 are represented exactly.
        let bf16_1 = (1.0f32.to_bits() >> 16) as u16;
        let bf16_2 = (2.0f32.to_bits() >> 16) as u16;
        let ones = [bf16_1; 512];
        let twos = [bf16_2; 512];
        let mut res = [[0f32; 16]; 16];
        let config = config_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(&ones as *const u16 as *const u8, 64);
        _tile_loadd::<2>(&twos as *const u16 as *const u8, 64);
        _tile_dpbf16ps::<0, 1, 2>();
        _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
        _tile_release();
        // Each result sums 32 products (64-byte rows of 2-byte elements) of 1.0 * 2.0.
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbssd() {
        _init_amx();
        let a = [-1_i8; 1024];
        let b = [-2_i8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = config_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(&a as *const i8 as *const u8, 64);
        _tile_loadd::<2>(&b as *const i8 as *const u8, 64);
        _tile_dpbssd::<0, 1, 2>();
        _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
        _tile_release();
        // 64 byte products of (-1) * (-2) per result.
        assert_eq!(res, [[128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbsud() {
        _init_amx();
        let a = [-1_i8; 1024];
        let b = [2_u8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = config_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(&a as *const i8 as *const u8, 64);
        _tile_loadd::<2>(&b as *const u8, 64);
        _tile_dpbsud::<0, 1, 2>();
        _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
        _tile_release();
        // 64 byte products of (-1) * 2 per result.
        assert_eq!(res, [[-128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbusd() {
        _init_amx();
        let a = [1_u8; 1024];
        let b = [-2_i8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = config_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(&a as *const u8, 64);
        _tile_loadd::<2>(&b as *const i8 as *const u8, 64);
        _tile_dpbusd::<0, 1, 2>();
        _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
        _tile_release();
        // 64 byte products of 1 * (-2) per result.
        assert_eq!(res, [[-128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbuud() {
        _init_amx();
        let a = [1_u8; 1024];
        let b = [2_u8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = config_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(&a as *const u8, 64);
        _tile_loadd::<2>(&b as *const u8, 64);
        _tile_dpbuud::<0, 1, 2>();
        _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
        _tile_release();
        // 64 byte products of 1 * 2 per result.
        assert_eq!(res, [[128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-fp16")]
    unsafe fn test_tile_dpfp16ps() {
        _init_amx();
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = config_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
        _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
        _tile_dpfp16ps::<0, 1, 2>();
        _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
        _tile_release();
        // Each result sums 32 products of 1.0 * 2.0.
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-complex")]
    unsafe fn test_tile_cmmimfp16ps() {
        _init_amx();
        // Every dword is the complex number 1 + 1i (in `ones`) or 2 + 2i (in `twos`).
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = config_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
        _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
        _tile_cmmimfp16ps::<0, 1, 2>();
        _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
        _tile_release();
        // Imaginary part per pair: 1*2 + 1*2 = 4, summed over 16 complex pairs.
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-complex")]
    unsafe fn test_tile_cmmrlfp16ps() {
        _init_amx();
        // Every dword is the complex number 1 + 1i (in `ones`) or 2 + 2i (in `twos`).
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = config_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
        _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
        _tile_cmmrlfp16ps::<0, 1, 2>();
        _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
        _tile_release();
        // Real part per pair: 1*2 - 1*2 = 0.
        assert_eq!(res, [[0f32; 16]; 16]);
    }
}