@@ -4,6 +4,7 @@ use std::cmp;
44use  libc:: c_uint; 
55use  rustc_abi:: { BackendRepr ,  HasDataLayout ,  Primitive ,  Reg ,  RegKind ,  Size } ; 
66use  rustc_codegen_ssa:: MemFlags ; 
7+ use  rustc_codegen_ssa:: common:: TypeKind ; 
78use  rustc_codegen_ssa:: mir:: operand:: { OperandRef ,  OperandValue } ; 
89use  rustc_codegen_ssa:: mir:: place:: { PlaceRef ,  PlaceValue } ; 
910use  rustc_codegen_ssa:: traits:: * ; 
@@ -308,7 +309,7 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
308309} 
309310
310311pub ( crate )  trait  FnAbiLlvmExt < ' ll ,  ' tcx >  { 
311-     fn  llvm_type ( & self ,  cx :  & CodegenCx < ' ll ,  ' tcx > )  -> & ' ll  Type ; 
312+     fn  llvm_type ( & self ,  cx :  & CodegenCx < ' ll ,  ' tcx > ,   name :   & [ u8 ] )  -> & ' ll  Type ; 
312313    fn  ptr_to_llvm_type ( & self ,  cx :  & CodegenCx < ' ll ,  ' tcx > )  -> & ' ll  Type ; 
313314    fn  llvm_cconv ( & self ,  cx :  & CodegenCx < ' ll ,  ' tcx > )  -> llvm:: CallConv ; 
314315
@@ -325,26 +326,45 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
325326} 
326327
327328impl < ' ll ,  ' tcx >  FnAbiLlvmExt < ' ll ,  ' tcx >  for  FnAbi < ' tcx ,  Ty < ' tcx > >  { 
328-     fn  llvm_type ( & self ,  cx :  & CodegenCx < ' ll ,  ' tcx > )  -> & ' ll  Type  { 
329+     fn  llvm_type ( & self ,  cx :  & CodegenCx < ' ll ,  ' tcx > ,   name :   & [ u8 ] )  -> & ' ll  Type  { 
329330        // Ignore "extra" args from the call site for C variadic functions. 
330331        // Only the "fixed" args are part of the LLVM function signature. 
331332        let  args =
332333            if  self . c_variadic  {  & self . args [ ..self . fixed_count  as  usize ]  }  else  {  & self . args  } ; 
333334
335+         // TODO(sayantn): it would be more robust to check the `link_name` rather than the function name, because the function name can be "faked" using `#[export_name]`. 
336+         let  llvm_intrinsic = name. starts_with ( b"llvm." ) 
337+             && !self . c_variadic 
338+             && self . conv  == Conv :: C 
339+             && !self . can_unwind ; 
340+         let  amx_intrinsic =
341+             llvm_intrinsic && name. starts_with ( b"llvm.x86." )  && name. ends_with ( b".internal" ) ; 
342+         let  adjust_ty = |ty| { 
343+             // Change type to `x86amx` from `i32x256` for x86_64 AMX intrinsics 
344+             if  amx_intrinsic && cx. type_kind ( ty)  == TypeKind :: Vector  && cx. vector_length ( ty)  == 256 
345+             { 
346+                 let  element_ty = cx. element_type ( ty) ; 
347+                 if  cx. type_kind ( element_ty)  == TypeKind :: Integer  && cx. int_width ( element_ty)  == 32  { 
348+                     return  cx. type_x86amx ( ) ; 
349+                 } 
350+             } 
351+             ty
352+         } ; 
353+ 
334354        // This capacity calculation is approximate. 
335355        let  mut  llargument_tys = Vec :: with_capacity ( 
336356            self . args . len ( )  + if  let  PassMode :: Indirect  {  .. }  = self . ret . mode  {  1  }  else  {  0  } , 
337357        ) ; 
338358
339-         let  llreturn_ty = match  & self . ret . mode  { 
359+         let  llreturn_ty = adjust_ty ( match  & self . ret . mode  { 
340360            PassMode :: Ignore  => cx. type_void ( ) , 
341361            PassMode :: Direct ( _)  | PassMode :: Pair ( ..)  => self . ret . layout . immediate_llvm_type ( cx) , 
342362            PassMode :: Cast  {  cast,  pad_i32 :  _ }  => cast. llvm_type ( cx) , 
343363            PassMode :: Indirect  {  .. }  => { 
344364                llargument_tys. push ( cx. type_ptr ( ) ) ; 
345365                cx. type_void ( ) 
346366            } 
347-         } ; 
367+         } ) ; 
348368
349369        for  arg in  args { 
350370            // Note that the exact number of arguments pushed here is carefully synchronized with 
@@ -388,7 +408,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
388408                    cast. llvm_type ( cx) 
389409                } 
390410            } ; 
391-             llargument_tys. push ( llarg_ty) ; 
411+             llargument_tys. push ( adjust_ty ( llarg_ty) ) ; 
392412        } 
393413
394414        if  self . c_variadic  { 
0 commit comments