@@ -3854,13 +3854,36 @@ LLVM_Util::mask_as_int8(llvm::Value* mask)
38543854llvm::Value*
38553855LLVM_Util::mask4_as_int8 (llvm::Value* mask)
38563856{
3857- OSL_ASSERT (m_supports_llvm_bit_masks_natively);
3858- // combine <4xi1> mask with <4xi1> zero init to get <8xi1> and cast it
3859- // to i8
3860- llvm::Value* zero_mask4
3861- = llvm::ConstantDataVector::getSplat (4 , constant_bool (false ));
3862- return builder ().CreateBitCast (op_combine_4x_vectors (mask, zero_mask4),
3863- type_int8 ());
3857+ if (m_supports_llvm_bit_masks_natively) {
3858+ // Native bit masks: combine the <4xi1> mask with a <4xi1> all-false
3859+ // vector into an <8xi1>, then bitcast that to i8
3860+ llvm::Value* zero_mask4
3861+ = llvm::ConstantDataVector::getSplat (4 , constant_bool (false ));
3862+ return builder ().CreateBitCast (op_combine_4x_vectors (mask, zero_mask4),
3863+ type_int8 ());
3864+ } else {
3865+ // Sign-extend <4 x i1> -> <4 x i32>, so each true lane has its sign bit set
3866+ llvm::Value* wide_int_mask = builder ().CreateSExt (mask,
3867+ type_wide_int ());
3868+
3869+ // Use the SSE movmsk.ps intrinsic to collect the sign bit of each lane
3870+ // into a 32 bit mask value. That intrinsic takes 4 floats (128 bits),
3871+ // not ints, so first bitcast the int32 lanes to float lanes (the
3872+ // bits are unchanged)
3873+ llvm::Type* w4_float_type = llvm_vector_type (m_llvm_type_float, 4 );
3874+ llvm::Value* w4_float_mask = builder ().CreateBitCast (wide_int_mask,
3875+ w4_float_type);
3876+
3877+ llvm::Function* func = llvm::Intrinsic::getDeclaration (
3878+ module (), llvm::Intrinsic::x86_sse_movmsk_ps);
3879+
3880+ llvm::Value* args[1 ] = { w4_float_mask };
3881+ llvm::Value* int32 = builder ().CreateCall (func, toArrayRef (args));
3882+
3883+ llvm::Value* i8 = builder ().CreateIntCast (int32, type_int8 (), true ); // narrow the 32 bit result to i8
3884+
3885+ return i8 ;
3886+ }
38643887}
38653888
38663889
@@ -4013,17 +4036,22 @@ LLVM_Util::op_1st_active_lane_of(llvm::Value* mask)
40134036 intMaskType = type_int8 ();
40144037 break ;
40154038 case 4 : {
4016- // We can just reinterpret cast a 4 bit mask to a 8 bit integer
4017- // and all types are happy
40184039 intMaskType = type_int8 ();
40194040
4020- // extended_int_vector_type = (llvm::Type *) llvm::VectorType::get(llvm::Type::getInt32Ty (*m_llvm_context), m_vector_width);
4021- // llvm::Value * wide_int_mask = builder().CreateSExt(mask, extended_int_vector_type);
4022- //
4023- // int_reinterpret_cast_vector_type = (llvm::Type *) llvm::Type::getInt128Ty (*m_llvm_context);
4024- // zeroConstant = constant128(0);
4025- //
4026- // llvm::Value * mask_as_int = builder().CreateBitCast (wide_int_mask, int_reinterpret_cast_vector_type);
4041+ llvm::Value* mask_as_int = mask4_as_int8 (mask);
4042+
4043+ // Count trailing zeros, least significant
4044+ llvm::Type* types[] = { intMaskType };
4045+ llvm::Function* func_cttz
4046+ = llvm::Intrinsic::getDeclaration (module (), llvm::Intrinsic::cttz,
4047+ toArrayRef (types));
4048+
4049+ llvm::Value* args[2 ] = { mask_as_int, constant_bool (true ) };
4050+
4051+ llvm::Value* firstNonZeroIndex = builder ().CreateCall (func_cttz,
4052+ toArrayRef (args));
4053+ return firstNonZeroIndex;
4054+
40274055 break ;
40284056 }
40294057 default : OSL_ASSERT (0 && " unsupported native bit mask width" );
0 commit comments