diff --git a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java index cfc2b6722d9..e9c6da7698f 100644 --- a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java +++ b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java @@ -24,27 +24,23 @@ */ package hat.backend.ffi; - import hat.ComputeContext; import hat.Config; import hat.KernelContext; import hat.callgraph.KernelCallGraph; import hat.callgraph.MethodCallDag; -import jdk.incubator.code.CodeTransformer; -import optkl.Trxfmr; -import optkl.codebuilders.ScopedCodeBuilderContext; -import optkl.util.CallSite; -import optkl.ifacemapper.Buffer; -import optkl.ifacemapper.BoundSchema; -import optkl.ifacemapper.MappableIface; -import optkl.FuncOpParams; - - import jdk.incubator.code.CodeContext; +import jdk.incubator.code.CodeTransformer; import jdk.incubator.code.Op; import jdk.incubator.code.Value; import jdk.incubator.code.dialect.core.CoreOp; import jdk.incubator.code.dialect.core.SSA; +import optkl.FuncOpParams; +import optkl.Trxfmr; +import optkl.codebuilders.ScopedCodeBuilderContext; +import optkl.ifacemapper.BoundSchema; +import optkl.ifacemapper.Buffer; +import optkl.ifacemapper.MappableIface; import java.lang.foreign.Arena; import java.lang.invoke.MethodHandles; @@ -52,6 +48,7 @@ import java.util.HashSet; import java.util.List; import java.util.Set; + import static optkl.OpHelper.Invoke; import static optkl.OpHelper.Invoke.invoke; @@ -62,345 +59,346 @@ public class CudaBackend extends C99FFIBackend { final int addressSize = 64; final static HashMap mathFns = new HashMap<>(); + static { mathFns.put("log_float", """ - .func (.param .b32 func_retval0) log( - .param .b32 log_param_0 - ) - { - .reg .pred %p<4>; - .reg .f32 %f<36>; - .reg .b32 %r<5>; - ld.param.f32 %f5, [log_param_0]; - setp.lt.f32 %p1, %f5, 0f00800000; - mul.f32 %f6, %f5, 0f4B000000; - selp.f32 %f1, %f6, %f5, %p1; - selp.f32 %f7, 0fC1B80000, 0f00000000, %p1; - mov.b32 %r1, %f1; - add.s32 %r2, %r1, -1059760811; - and.b32 %r3, %r2, -8388608; - sub.s32 %r4, %r1, %r3; - mov.b32 %f8, %r4; - cvt.rn.f32.s32 %f9, %r3; - mov.f32 %f10, 0f34000000; - fma.rn.f32 %f11, %f9, %f10, %f7; - add.f32 %f12, %f8, 0fBF800000; - mov.f32 %f13, 0f3E1039F6; - mov.f32 %f14, 0fBE055027; - fma.rn.f32 %f15, %f14, %f12, %f13; - mov.f32 %f16, 0fBDF8CDCC; - fma.rn.f32 %f17, %f15, %f12, %f16; - mov.f32 %f18, 0f3E0F2955; - fma.rn.f32 %f19, %f17, %f12, %f18; - mov.f32 %f20, 0fBE2AD8B9; - fma.rn.f32 %f21, %f19, %f12, %f20; - mov.f32 %f22, 0f3E4CED0B; - fma.rn.f32 %f23, %f21, %f12, %f22; - mov.f32 %f24, 0fBE7FFF22; - fma.rn.f32 %f25, %f23, %f12, %f24; - mov.f32 %f26, 0f3EAAAA78; - fma.rn.f32 %f27, %f25, %f12, %f26; - mov.f32 %f28, 0fBF000000; - fma.rn.f32 %f29, %f27, %f12, %f28; - mul.f32 %f30, %f12, %f29; - fma.rn.f32 %f31, %f30, %f12, %f12; - mov.f32 %f32, 0f3F317218; - fma.rn.f32 %f35, %f11, %f32, %f31; - setp.lt.u32 %p2, %r1, 2139095040; - @%p2 bra $L__BB0_2; - mov.f32 %f33, 0f7F800000; - fma.rn.f32 %f35, %f1, %f33, %f33; - $L__BB0_2: - setp.eq.f32 %p3, %f1, 0f00000000; - selp.f32 %f34, 0fFF800000, %f35, %p3; - st.param.f32 [func_retval0+0], %f34; - ret; - }""" + .func (.param .b32 func_retval0) log( + .param .b32 log_param_0 + ) + { + .reg .pred %p<4>; + .reg .f32 %f<36>; + .reg .b32 %r<5>; + ld.param.f32 %f5, [log_param_0]; + setp.lt.f32 %p1, %f5, 0f00800000; + mul.f32 %f6, %f5, 0f4B000000; + selp.f32 %f1, %f6, %f5, %p1; + selp.f32 %f7, 0fC1B80000, 0f00000000, %p1; + mov.b32 %r1, %f1; + add.s32 %r2, %r1, -1059760811; + and.b32 %r3, %r2, -8388608; + sub.s32 %r4, %r1, %r3; + mov.b32 %f8, %r4; + cvt.rn.f32.s32 %f9, %r3; + mov.f32 %f10, 0f34000000; + fma.rn.f32 %f11, %f9, %f10, %f7; + add.f32 %f12, %f8, 0fBF800000; + mov.f32 %f13, 0f3E1039F6; + mov.f32 %f14, 0fBE055027; + fma.rn.f32 %f15, %f14, %f12, %f13; + mov.f32 %f16, 0fBDF8CDCC; + fma.rn.f32 %f17, %f15, %f12, %f16; + mov.f32 %f18, 0f3E0F2955; + fma.rn.f32 %f19, %f17, %f12, %f18; + mov.f32 %f20, 0fBE2AD8B9; + fma.rn.f32 %f21, %f19, %f12, %f20; + mov.f32 %f22, 0f3E4CED0B; + fma.rn.f32 %f23, %f21, %f12, %f22; + mov.f32 %f24, 0fBE7FFF22; + fma.rn.f32 %f25, %f23, %f12, %f24; + mov.f32 %f26, 0f3EAAAA78; + fma.rn.f32 %f27, %f25, %f12, %f26; + mov.f32 %f28, 0fBF000000; + fma.rn.f32 %f29, %f27, %f12, %f28; + mul.f32 %f30, %f12, %f29; + fma.rn.f32 %f31, %f30, %f12, %f12; + mov.f32 %f32, 0f3F317218; + fma.rn.f32 %f35, %f11, %f32, %f31; + setp.lt.u32 %p2, %r1, 2139095040; + @%p2 bra $L__BB0_2; + mov.f32 %f33, 0f7F800000; + fma.rn.f32 %f35, %f1, %f33, %f33; + $L__BB0_2: + setp.eq.f32 %p3, %f1, 0f00000000; + selp.f32 %f34, 0fFF800000, %f35, %p3; + st.param.f32 [func_retval0+0], %f34; + ret; + }""" ); mathFns.put("log_double", """ - .func (.param .b64 func_retval0) log( - .param .b64 log_param_0 - ) - { - .reg .pred %p<5>; - .reg .f32 %f<2>; - .reg .b32 %r<28>; - .reg .f64 %fd<59>; - ld.param.f64 %fd56, [log_param_0]; - { - .reg .b32 %temp; - mov.b64 {%temp, %r24}, %fd56; - } - { - .reg .b32 %temp; - mov.b64 {%r25, %temp}, %fd56; - } - setp.gt.s32 %p1, %r24, 1048575; - mov.u32 %r26, -1023; - @%p1 bra $L__BB0_2; - mul.f64 %fd56, %fd56, 0d4350000000000000; - { - .reg .b32 %temp; - mov.b64 {%temp, %r24}, %fd56; - } - { - .reg .b32 %temp; - mov.b64 {%r25, %temp}, %fd56; - } - mov.u32 %r26, -1077; - $L__BB0_2: - add.s32 %r13, %r24, -1; - setp.lt.u32 %p2, %r13, 2146435071; - @%p2 bra $L__BB0_4; - bra.uni $L__BB0_3; - $L__BB0_4: - shr.u32 %r15, %r24, 20; - add.s32 %r27, %r26, %r15; - and.b32 %r16, %r24, -2146435073; - or.b32 %r17, %r16, 1072693248; - mov.b64 %fd57, {%r25, %r17}; - setp.lt.s32 %p4, %r17, 1073127583; - @%p4 bra $L__BB0_6; - { - .reg .b32 %temp; - mov.b64 {%r18, %temp}, %fd57; - } - { - .reg .b32 %temp; - mov.b64 {%temp, %r19}, %fd57; - } - add.s32 %r20, %r19, -1048576; - mov.b64 %fd57, {%r18, %r20}; - add.s32 %r27, %r27, 1; - $L__BB0_6: - add.f64 %fd12, %fd57, 0d3FF0000000000000; - mov.f64 %fd13, 0d3FF0000000000000; - rcp.approx.ftz.f64 %fd14, %fd12; - neg.f64 %fd15, %fd12; - fma.rn.f64 %fd16, %fd15, %fd14, %fd13; - fma.rn.f64 %fd17, %fd16, %fd16, %fd16; - fma.rn.f64 %fd18, %fd17, %fd14, %fd14; - add.f64 %fd19, %fd57, 0dBFF0000000000000; - mul.f64 %fd20, %fd19, %fd18; - fma.rn.f64 %fd21, %fd19, %fd18, %fd20; - mul.f64 %fd22, %fd21, %fd21; - mov.f64 %fd23, 0d3ED0EE258B7A8B04; - mov.f64 %fd24, 0d3EB1380B3AE80F1E; - fma.rn.f64 %fd25, %fd24, %fd22, %fd23; - mov.f64 %fd26, 0d3EF3B2669F02676F; - fma.rn.f64 %fd27, %fd25, %fd22, %fd26; - mov.f64 %fd28, 0d3F1745CBA9AB0956; - fma.rn.f64 %fd29, %fd27, %fd22, %fd28; - mov.f64 %fd30, 0d3F3C71C72D1B5154; - fma.rn.f64 %fd31, %fd29, %fd22, %fd30; - mov.f64 %fd32, 0d3F624924923BE72D; - fma.rn.f64 %fd33, %fd31, %fd22, %fd32; - mov.f64 %fd34, 0d3F8999999999A3C4; - fma.rn.f64 %fd35, %fd33, %fd22, %fd34; - mov.f64 %fd36, 0d3FB5555555555554; - fma.rn.f64 %fd37, %fd35, %fd22, %fd36; - sub.f64 %fd38, %fd19, %fd21; - add.f64 %fd39, %fd38, %fd38; - neg.f64 %fd40, %fd21; - fma.rn.f64 %fd41, %fd40, %fd19, %fd39; - mul.f64 %fd42, %fd18, %fd41; - mul.f64 %fd43, %fd22, %fd37; - fma.rn.f64 %fd44, %fd43, %fd21, %fd42; - xor.b32 %r21, %r27, -2147483648; - mov.u32 %r22, -2147483648; - mov.u32 %r23, 1127219200; - mov.b64 %fd45, {%r21, %r23}; - mov.b64 %fd46, {%r22, %r23}; - sub.f64 %fd47, %fd45, %fd46; - mov.f64 %fd48, 0d3FE62E42FEFA39EF; - fma.rn.f64 %fd49, %fd47, %fd48, %fd21; - neg.f64 %fd50, %fd47; - fma.rn.f64 %fd51, %fd50, %fd48, %fd49; - sub.f64 %fd52, %fd51, %fd21; - sub.f64 %fd53, %fd44, %fd52; - mov.f64 %fd54, 0d3C7ABC9E3B39803F; - fma.rn.f64 %fd55, %fd47, %fd54, %fd53; - add.f64 %fd58, %fd49, %fd55; - bra.uni $L__BB0_7; - $L__BB0_3: - mov.f64 %fd10, 0d7FF0000000000000; - fma.rn.f64 %fd11, %fd56, %fd10, %fd10; - { - .reg .b32 %temp; - mov.b64 {%temp, %r14}, %fd56; - } - mov.b32 %f1, %r14; - setp.eq.f32 %p3, %f1, 0f00000000; - selp.f64 %fd58, 0dFFF0000000000000, %fd11, %p3; - $L__BB0_7: - st.param.f64 [func_retval0+0], %fd58; - ret; - }""" + .func (.param .b64 func_retval0) log( + .param .b64 log_param_0 + ) + { + .reg .pred %p<5>; + .reg .f32 %f<2>; + .reg .b32 %r<28>; + .reg .f64 %fd<59>; + ld.param.f64 %fd56, [log_param_0]; + { + .reg .b32 %temp; + mov.b64 {%temp, %r24}, %fd56; + } + { + .reg .b32 %temp; + mov.b64 {%r25, %temp}, %fd56; + } + setp.gt.s32 %p1, %r24, 1048575; + mov.u32 %r26, -1023; + @%p1 bra $L__BB0_2; + mul.f64 %fd56, %fd56, 0d4350000000000000; + { + .reg .b32 %temp; + mov.b64 {%temp, %r24}, %fd56; + } + { + .reg .b32 %temp; + mov.b64 {%r25, %temp}, %fd56; + } + mov.u32 %r26, -1077; + $L__BB0_2: + add.s32 %r13, %r24, -1; + setp.lt.u32 %p2, %r13, 2146435071; + @%p2 bra $L__BB0_4; + bra.uni $L__BB0_3; + $L__BB0_4: + shr.u32 %r15, %r24, 20; + add.s32 %r27, %r26, %r15; + and.b32 %r16, %r24, -2146435073; + or.b32 %r17, %r16, 1072693248; + mov.b64 %fd57, {%r25, %r17}; + setp.lt.s32 %p4, %r17, 1073127583; + @%p4 bra $L__BB0_6; + { + .reg .b32 %temp; + mov.b64 {%r18, %temp}, %fd57; + } + { + .reg .b32 %temp; + mov.b64 {%temp, %r19}, %fd57; + } + add.s32 %r20, %r19, -1048576; + mov.b64 %fd57, {%r18, %r20}; + add.s32 %r27, %r27, 1; + $L__BB0_6: + add.f64 %fd12, %fd57, 0d3FF0000000000000; + mov.f64 %fd13, 0d3FF0000000000000; + rcp.approx.ftz.f64 %fd14, %fd12; + neg.f64 %fd15, %fd12; + fma.rn.f64 %fd16, %fd15, %fd14, %fd13; + fma.rn.f64 %fd17, %fd16, %fd16, %fd16; + fma.rn.f64 %fd18, %fd17, %fd14, %fd14; + add.f64 %fd19, %fd57, 0dBFF0000000000000; + mul.f64 %fd20, %fd19, %fd18; + fma.rn.f64 %fd21, %fd19, %fd18, %fd20; + mul.f64 %fd22, %fd21, %fd21; + mov.f64 %fd23, 0d3ED0EE258B7A8B04; + mov.f64 %fd24, 0d3EB1380B3AE80F1E; + fma.rn.f64 %fd25, %fd24, %fd22, %fd23; + mov.f64 %fd26, 0d3EF3B2669F02676F; + fma.rn.f64 %fd27, %fd25, %fd22, %fd26; + mov.f64 %fd28, 0d3F1745CBA9AB0956; + fma.rn.f64 %fd29, %fd27, %fd22, %fd28; + mov.f64 %fd30, 0d3F3C71C72D1B5154; + fma.rn.f64 %fd31, %fd29, %fd22, %fd30; + mov.f64 %fd32, 0d3F624924923BE72D; + fma.rn.f64 %fd33, %fd31, %fd22, %fd32; + mov.f64 %fd34, 0d3F8999999999A3C4; + fma.rn.f64 %fd35, %fd33, %fd22, %fd34; + mov.f64 %fd36, 0d3FB5555555555554; + fma.rn.f64 %fd37, %fd35, %fd22, %fd36; + sub.f64 %fd38, %fd19, %fd21; + add.f64 %fd39, %fd38, %fd38; + neg.f64 %fd40, %fd21; + fma.rn.f64 %fd41, %fd40, %fd19, %fd39; + mul.f64 %fd42, %fd18, %fd41; + mul.f64 %fd43, %fd22, %fd37; + fma.rn.f64 %fd44, %fd43, %fd21, %fd42; + xor.b32 %r21, %r27, -2147483648; + mov.u32 %r22, -2147483648; + mov.u32 %r23, 1127219200; + mov.b64 %fd45, {%r21, %r23}; + mov.b64 %fd46, {%r22, %r23}; + sub.f64 %fd47, %fd45, %fd46; + mov.f64 %fd48, 0d3FE62E42FEFA39EF; + fma.rn.f64 %fd49, %fd47, %fd48, %fd21; + neg.f64 %fd50, %fd47; + fma.rn.f64 %fd51, %fd50, %fd48, %fd49; + sub.f64 %fd52, %fd51, %fd21; + sub.f64 %fd53, %fd44, %fd52; + mov.f64 %fd54, 0d3C7ABC9E3B39803F; + fma.rn.f64 %fd55, %fd47, %fd54, %fd53; + add.f64 %fd58, %fd49, %fd55; + bra.uni $L__BB0_7; + $L__BB0_3: + mov.f64 %fd10, 0d7FF0000000000000; + fma.rn.f64 %fd11, %fd56, %fd10, %fd10; + { + .reg .b32 %temp; + mov.b64 {%temp, %r14}, %fd56; + } + mov.b32 %f1, %r14; + setp.eq.f32 %p3, %f1, 0f00000000; + selp.f64 %fd58, 0dFFF0000000000000, %fd11, %p3; + $L__BB0_7: + st.param.f64 [func_retval0+0], %fd58; + ret; + }""" ); mathFns.put("exp_float", """ - .func (.param .b32 func_retval0) exp( - .param .b32 exp_param_0 - ) - { - .reg .f32 %f<18>; - .reg .b32 %r<3>; - ld.param.f32 %f1, [exp_param_0]; - mov.f32 %f2, 0f3F000000; - mov.f32 %f3, 0f3BBB989D; - fma.rn.f32 %f4, %f1, %f3, %f2; - mov.f32 %f5, 0f3FB8AA3B; - mov.f32 %f6, 0f437C0000; - cvt.sat.f32.f32 %f7, %f4; - mov.f32 %f8, 0f4B400001; - fma.rm.f32 %f9, %f7, %f6, %f8; - add.f32 %f10, %f9, 0fCB40007F; - neg.f32 %f11, %f10; - fma.rn.f32 %f12, %f1, %f5, %f11; - mov.f32 %f13, 0f32A57060; - fma.rn.f32 %f14, %f1, %f13, %f12; - mov.b32 %r1, %f9; - shl.b32 %r2, %r1, 23; - mov.b32 %f15, %r2; - ex2.approx.ftz.f32 %f16, %f14; - mul.f32 %f17, %f16, %f15; - st.param.f32 [func_retval0+0], %f17; - ret; - }""" + .func (.param .b32 func_retval0) exp( + .param .b32 exp_param_0 + ) + { + .reg .f32 %f<18>; + .reg .b32 %r<3>; + ld.param.f32 %f1, [exp_param_0]; + mov.f32 %f2, 0f3F000000; + mov.f32 %f3, 0f3BBB989D; + fma.rn.f32 %f4, %f1, %f3, %f2; + mov.f32 %f5, 0f3FB8AA3B; + mov.f32 %f6, 0f437C0000; + cvt.sat.f32.f32 %f7, %f4; + mov.f32 %f8, 0f4B400001; + fma.rm.f32 %f9, %f7, %f6, %f8; + add.f32 %f10, %f9, 0fCB40007F; + neg.f32 %f11, %f10; + fma.rn.f32 %f12, %f1, %f5, %f11; + mov.f32 %f13, 0f32A57060; + fma.rn.f32 %f14, %f1, %f13, %f12; + mov.b32 %r1, %f9; + shl.b32 %r2, %r1, 23; + mov.b32 %f15, %r2; + ex2.approx.ftz.f32 %f16, %f14; + mul.f32 %f17, %f16, %f15; + st.param.f32 [func_retval0+0], %f17; + ret; + }""" ); mathFns.put("exp_double", """ - .func (.param .b64 func_retval0) exp( - .param .b64 exp_param_0 - ) - { - .reg .pred %p<4>; - .reg .f32 %f<3>; - .reg .b32 %r<16>; - .reg .f64 %fd<41>; - ld.param.f64 %fd5, [exp_param_0]; - mov.f64 %fd6, 0d4338000000000000; - mov.f64 %fd7, 0d3FF71547652B82FE; - fma.rn.f64 %fd8, %fd5, %fd7, %fd6; - { - .reg .b32 %temp; - mov.b64 {%r1, %temp}, %fd8; - } - mov.f64 %fd9, 0dC338000000000000; - add.rn.f64 %fd10, %fd8, %fd9; - mov.f64 %fd11, 0dBFE62E42FEFA39EF; - fma.rn.f64 %fd12, %fd10, %fd11, %fd5; - mov.f64 %fd13, 0dBC7ABC9E3B39803F; - fma.rn.f64 %fd14, %fd10, %fd13, %fd12; - mov.f64 %fd15, 0d3E928AF3FCA213EA; - mov.f64 %fd16, 0d3E5ADE1569CE2BDF; - fma.rn.f64 %fd17, %fd16, %fd14, %fd15; - mov.f64 %fd18, 0d3EC71DEE62401315; - fma.rn.f64 %fd19, %fd17, %fd14, %fd18; - mov.f64 %fd20, 0d3EFA01997C89EB71; - fma.rn.f64 %fd21, %fd19, %fd14, %fd20; - mov.f64 %fd22, 0d3F2A01A014761F65; - fma.rn.f64 %fd23, %fd21, %fd14, %fd22; - mov.f64 %fd24, 0d3F56C16C1852B7AF; - fma.rn.f64 %fd25, %fd23, %fd14, %fd24; - mov.f64 %fd26, 0d3F81111111122322; - fma.rn.f64 %fd27, %fd25, %fd14, %fd26; - mov.f64 %fd28, 0d3FA55555555502A1; - fma.rn.f64 %fd29, %fd27, %fd14, %fd28; - mov.f64 %fd30, 0d3FC5555555555511; - fma.rn.f64 %fd31, %fd29, %fd14, %fd30; - mov.f64 %fd32, 0d3FE000000000000B; - fma.rn.f64 %fd33, %fd31, %fd14, %fd32; - mov.f64 %fd34, 0d3FF0000000000000; - fma.rn.f64 %fd35, %fd33, %fd14, %fd34; - fma.rn.f64 %fd36, %fd35, %fd14, %fd34; - { - .reg .b32 %temp; - mov.b64 {%r2, %temp}, %fd36; - } - { - .reg .b32 %temp; - mov.b64 {%temp, %r3}, %fd36; - } - shl.b32 %r4, %r1, 20; - add.s32 %r5, %r3, %r4; - mov.b64 %fd40, {%r2, %r5}; - { - .reg .b32 %temp; - mov.b64 {%temp, %r6}, %fd5; - } - mov.b32 %f2, %r6; - abs.f32 %f1, %f2; - setp.lt.f32 %p1, %f1, 0f4086232B; - @%p1 bra $L__BB0_3; + .func (.param .b64 func_retval0) exp( + .param .b64 exp_param_0 + ) + { + .reg .pred %p<4>; + .reg .f32 %f<3>; + .reg .b32 %r<16>; + .reg .f64 %fd<41>; + ld.param.f64 %fd5, [exp_param_0]; + mov.f64 %fd6, 0d4338000000000000; + mov.f64 %fd7, 0d3FF71547652B82FE; + fma.rn.f64 %fd8, %fd5, %fd7, %fd6; + { + .reg .b32 %temp; + mov.b64 {%r1, %temp}, %fd8; + } + mov.f64 %fd9, 0dC338000000000000; + add.rn.f64 %fd10, %fd8, %fd9; + mov.f64 %fd11, 0dBFE62E42FEFA39EF; + fma.rn.f64 %fd12, %fd10, %fd11, %fd5; + mov.f64 %fd13, 0dBC7ABC9E3B39803F; + fma.rn.f64 %fd14, %fd10, %fd13, %fd12; + mov.f64 %fd15, 0d3E928AF3FCA213EA; + mov.f64 %fd16, 0d3E5ADE1569CE2BDF; + fma.rn.f64 %fd17, %fd16, %fd14, %fd15; + mov.f64 %fd18, 0d3EC71DEE62401315; + fma.rn.f64 %fd19, %fd17, %fd14, %fd18; + mov.f64 %fd20, 0d3EFA01997C89EB71; + fma.rn.f64 %fd21, %fd19, %fd14, %fd20; + mov.f64 %fd22, 0d3F2A01A014761F65; + fma.rn.f64 %fd23, %fd21, %fd14, %fd22; + mov.f64 %fd24, 0d3F56C16C1852B7AF; + fma.rn.f64 %fd25, %fd23, %fd14, %fd24; + mov.f64 %fd26, 0d3F81111111122322; + fma.rn.f64 %fd27, %fd25, %fd14, %fd26; + mov.f64 %fd28, 0d3FA55555555502A1; + fma.rn.f64 %fd29, %fd27, %fd14, %fd28; + mov.f64 %fd30, 0d3FC5555555555511; + fma.rn.f64 %fd31, %fd29, %fd14, %fd30; + mov.f64 %fd32, 0d3FE000000000000B; + fma.rn.f64 %fd33, %fd31, %fd14, %fd32; + mov.f64 %fd34, 0d3FF0000000000000; + fma.rn.f64 %fd35, %fd33, %fd14, %fd34; + fma.rn.f64 %fd36, %fd35, %fd14, %fd34; + { + .reg .b32 %temp; + mov.b64 {%r2, %temp}, %fd36; + } + { + .reg .b32 %temp; + mov.b64 {%temp, %r3}, %fd36; + } + shl.b32 %r4, %r1, 20; + add.s32 %r5, %r3, %r4; + mov.b64 %fd40, {%r2, %r5}; + { + .reg .b32 %temp; + mov.b64 {%temp, %r6}, %fd5; + } + mov.b32 %f2, %r6; + abs.f32 %f1, %f2; + setp.lt.f32 %p1, %f1, 0f4086232B; + @%p1 bra $L__BB0_3; - setp.lt.f64 %p2, %fd5, 0d0000000000000000; - add.f64 %fd37, %fd5, 0d7FF0000000000000; - selp.f64 %fd40, 0d0000000000000000, %fd37, %p2; - setp.geu.f32 %p3, %f1, 0f40874800; - @%p3 bra $L__BB0_3; + setp.lt.f64 %p2, %fd5, 0d0000000000000000; + add.f64 %fd37, %fd5, 0d7FF0000000000000; + selp.f64 %fd40, 0d0000000000000000, %fd37, %p2; + setp.geu.f32 %p3, %f1, 0f40874800; + @%p3 bra $L__BB0_3; - shr.u32 %r7, %r1, 31; - add.s32 %r8, %r1, %r7; - shr.s32 %r9, %r8, 1; - shl.b32 %r10, %r9, 20; - add.s32 %r11, %r3, %r10; - mov.b64 %fd38, {%r2, %r11}; - sub.s32 %r12, %r1, %r9; - shl.b32 %r13, %r12, 20; - add.s32 %r14, %r13, 1072693248; - mov.u32 %r15, 0; - mov.b64 %fd39, {%r15, %r14}; - mul.f64 %fd40, %fd38, %fd39; - $L__BB0_3: - st.param.f64 [func_retval0+0], %fd40; - ret; - }""" + shr.u32 %r7, %r1, 31; + add.s32 %r8, %r1, %r7; + shr.s32 %r9, %r8, 1; + shl.b32 %r10, %r9, 20; + add.s32 %r11, %r3, %r10; + mov.b64 %fd38, {%r2, %r11}; + sub.s32 %r12, %r1, %r9; + shl.b32 %r13, %r12, 20; + add.s32 %r14, %r13, 1072693248; + mov.u32 %r15, 0; + mov.b64 %fd39, {%r15, %r14}; + mul.f64 %fd40, %fd38, %fd39; + $L__BB0_3: + st.param.f64 [func_retval0+0], %fd40; + ret; + }""" ); } - final Set usedMathFns = new HashSet<>(); public CudaBackend(Config config) { - super(Arena.global(), MethodHandles.lookup(),"cuda_backend", config); + super(Arena.global(), MethodHandles.lookup(), "cuda_backend", config); } public CudaBackend() { this(Config.fromEnvOrProperty()); } + @Override public void computeContextHandoff(ComputeContext computeContext) { - computeContext.computeCallGraph().callDag.entryPoint.funcOp(injectBufferTracking(config(),lookup(),computeContext.computeCallGraph().callDag.entryPoint.funcOp())); + computeContext.computeCallGraph().callDag.entryPoint.funcOp(injectBufferTracking(config(), lookup(), computeContext.computeCallGraph().callDag.entryPoint.funcOp())); } @Override public void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernelContext, Object... args) { CompiledKernel compiledKernel = kernelCallGraphCompiledCodeMap.computeIfAbsent(kernelCallGraph, (_) -> { - String code =config().ptx() ? createPTX(kernelCallGraph, args) : createC99(kernelCallGraph, args); + String code = config().ptx() ? createPTX(kernelCallGraph, args) : createC99(kernelCallGraph, args); if (config().showCode()) { System.out.println(code); } var compilationUnit = backendBridge.compile(code); if (compilationUnit.ok()) { var kernel = compilationUnit.getKernel(kernelCallGraph.callDag.entryPoint.method().getName()); - return new CompiledKernel(this, kernelCallGraph, kernel, args); + return new CompiledKernel(this, kernelCallGraph, kernel, args); } else { throw new IllegalStateException("cuda failed to compile "); } }); compiledKernel.dispatch(kernelContext, args); } - String createC99(KernelCallGraph kernelCallGraph, Object... args){ - return createCode(kernelCallGraph, new CudaHATKernelBuilder(kernelCallGraph,new ScopedCodeBuilderContext(kernelCallGraph.lookup(),kernelCallGraph.callDag.entryPoint.funcOp())), args); + + String createC99(KernelCallGraph kernelCallGraph, Object... args) { + return createCode(kernelCallGraph, new CudaHATKernelBuilder(kernelCallGraph, new ScopedCodeBuilderContext(kernelCallGraph.lookup(), kernelCallGraph.callDag.entryPoint.funcOp())), args); } /// Same as OpenCL backend until here - - String createPTX(KernelCallGraph kernelCallGraph, Object... args){ + String createPTX(KernelCallGraph kernelCallGraph, Object... args) { var builder = new PTXHATKernelBuilder(); StringBuilder out = new StringBuilder(); StringBuilder invokedMethods = new StringBuilder(); @@ -413,36 +411,36 @@ String createPTX(KernelCallGraph kernelCallGraph, Object... args){ out.append(builder.getText()); builder.clear(); - // var here = CallSite.of(CudaBackend.class, "createPTX"); + // var here = CallSite.of(CudaBackend.class, "createPTX"); kernelCallGraph.callDag.rankOrdered.stream() - .filter(m->m instanceof MethodCallDag.OtherMethodCall) + .filter(m -> m instanceof MethodCallDag.OtherMethodCall) .forEach(f -> { CoreOp.FuncOp loweredFunc = f.funcOp().transform(CodeTransformer.LOWERING_TRANSFORMER); - loweredFunc = transformPTXPtrs(kernelCallGraph.lookup(),loweredFunc, argsMap, usedMathFns); - invokedMethods.append(createFunction(kernelCallGraph.lookup(),new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false)); + loweredFunc = transformPTXPtrs(kernelCallGraph.lookup(), loweredFunc, argsMap, usedMathFns); + invokedMethods.append(createFunction(kernelCallGraph.lookup(), new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false)); }); CoreOp.FuncOp lowered = kernelCallGraph.callDag.entryPoint.funcOp().transform(CodeTransformer.LOWERING_TRANSFORMER); - CoreOp.FuncOp loweredPtx = transformPTXPtrs(kernelCallGraph.lookup(),lowered, argsMap, usedMathFns); + CoreOp.FuncOp loweredPtx = transformPTXPtrs(kernelCallGraph.lookup(), lowered, argsMap, usedMathFns); for (String s : usedMathFns) { out.append("\n").append(mathFns.get(s)).append("\n"); } out.append(invokedMethods); - out.append(createFunction(kernelCallGraph.lookup(),builder.nl().nl(), loweredPtx, true)); - if (config().showKernelModel()){ - System.out.println("ptx follows\n"+out); + out.append(createFunction(kernelCallGraph.lookup(), builder.nl().nl(), loweredPtx, true)); + if (config().showKernelModel()) { + System.out.println("ptx follows\n" + out); } return out.toString(); } - static public CoreOp.FuncOp transformPTXPtrs(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp, HashMap argsMap, Set usedMathFns) { - return Trxfmr.of(lookup,funcOp).transform(_->true,(block, op) -> { + static public CoreOp.FuncOp transformPTXPtrs(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp, HashMap argsMap, Set usedMathFns) { + return Trxfmr.of(lookup, funcOp).transform(_ -> true, (block, op) -> { CodeContext cc = block.context(); // use first operand of invoke to figure out schema - if (invoke(lookup,op) instanceof Invoke invoke){ + if (invoke(lookup, op) instanceof Invoke invoke) { if (invoke.isMappableIface() && invoke.op().operands().getFirst() instanceof Op.Result invokeResult && invokeResult.op().operands().getFirst() instanceof Op.Result varLoadResult @@ -450,12 +448,12 @@ static public CoreOp.FuncOp transformPTXPtrs(MethodHandles.Lookup lookup,CoreOp && argsMap.get(varOp.varName()) instanceof Buffer buffer) { List inputOperands = invoke.op().operands(); List outputOperands = cc.getValues(inputOperands); - // Op.Result inputResult = invokeOp.result(); + // Op.Result inputResult = invokeOp.result(); BoundSchema boundSchema = MappableIface.getBoundSchema(buffer); PTXPtrOp ptxOp = new PTXPtrOp(invoke.returnType(), invoke.name(), outputOperands, boundSchema); Op.Result outputResult = block.op(ptxOp); cc.mapValue(invoke.op().result(), outputResult); - } else if (invoke.refIs(Math.class) && mathFns.containsKey(invoke.name() + "_" + invoke.returnType().toString())){ + } else if (invoke.refIs(Math.class) && mathFns.containsKey(invoke.name() + "_" + invoke.returnType().toString())) { usedMathFns.add(invoke.name() + "_" + invoke.returnType().toString()); block.op(op); } else { @@ -468,9 +466,8 @@ static public CoreOp.FuncOp transformPTXPtrs(MethodHandles.Lookup lookup,CoreOp }).funcOp(); } - static public String createFunction(MethodHandles.Lookup lookup,PTXHATKernelBuilder builder, CoreOp.FuncOp lowered, boolean entry) { - CoreOp.FuncOp ssa =SSA.transform(lowered); - + static public String createFunction(MethodHandles.Lookup lookup, PTXHATKernelBuilder builder, CoreOp.FuncOp lowered, boolean entry) { + CoreOp.FuncOp ssa = SSA.transform(lowered); // building fn info (name, params) builder.functionHeader(lowered.funcName(), entry, lowered.body().yieldType()); @@ -484,7 +481,7 @@ static public String createFunction(MethodHandles.Lookup lookup,PTXHATKernelBuil String out = builder.getText(); builder.clear(); ssa.bodies().getFirst().blocks().forEach(block -> - builder.blockBody(lookup,block, block.ops().stream())); + builder.blockBody(lookup, block, block.ops().stream())); builder.functionEpilogue(); String body = builder.getText(); diff --git a/hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend.cpp b/hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend.cpp index 730d9759237..6b78f10debd 100644 --- a/hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend.cpp +++ b/hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend.cpp @@ -154,7 +154,7 @@ PtxSource *CudaBackend::nvcc(const CudaSource *cudaSource) { // create var/cuda directory std::string localDirectory = "./var/cuda"; std::filesystem::create_directories(localDirectory); - // create temp file for cuda generarated code + // create temp file for cuda generated code const uint64_t time = timeSinceEpochMillisec(); const std::string ptxPath = tmpFileName(time, localDirectory, ".ptx"); const std::string cudaPath = tmpFileName(time, localDirectory, ".cu"); @@ -338,6 +338,12 @@ void CudaBackend::computeStart() { queue->computeStart(); } +std::string* CudaBackend::getDeviceVendor() { + // The CUDA Backend is owned by NVIDIA. Thus, no need to query + auto *vendor = new std::string("NVIDIA"); + return reinterpret_cast(vendor->data()); +} + bool CudaBackend::getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) { if (config->traceCalls) { std::cout << "getBufferFromDeviceIfDirty(" << std::hex << reinterpret_cast(memorySegment) << "," << diff --git a/hat/backends/ffi/cuda/src/main/native/include/cuda_backend.h b/hat/backends/ffi/cuda/src/main/native/include/cuda_backend.h index 760f5523bda..21cbe214a9d 100644 --- a/hat/backends/ffi/cuda/src/main/native/include/cuda_backend.h +++ b/hat/backends/ffi/cuda/src/main/native/include/cuda_backend.h @@ -104,23 +104,26 @@ class CudaSource final :public Text { class CudaBackend final : public Backend { public: -class CudaQueue final : public Backend::Queue { + class CudaQueue final : public Backend::Queue { public: std::thread::id streamCreationThread; CUstream cuStream; + explicit CudaQueue(Backend *backend); + void init(); + void wait() override; - void release() override; + void release() override; - void computeStart() override; + void computeStart() override; - void computeEnd() override; + void computeEnd() override; - void copyToDevice(Buffer *buffer) override; + void copyToDevice(Buffer *buffer) override; - void copyFromDevice(Buffer *buffer) override; + void copyFromDevice(Buffer *buffer) override; int estimateThreadsPerBlock(int dimensions); @@ -128,8 +131,9 @@ class CudaQueue final : public Backend::Queue { void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override; + ~CudaQueue() override; -}; + }; class CudaBuffer final : public Buffer { public: @@ -188,6 +192,8 @@ class CudaQueue final : public Backend::Queue { explicit CudaBackend(int mode); + std::string *getDeviceVendor() override; + ~CudaBackend() override; static CudaBackend * of(long backendHandle); static CudaBackend * of(Backend *backend); diff --git a/hat/backends/ffi/mock/src/main/java/hat/backend/ffi/MockBackend.java b/hat/backends/ffi/mock/src/main/java/hat/backend/ffi/MockBackend.java index 5b4ca727d9d..18122399566 100644 --- a/hat/backends/ffi/mock/src/main/java/hat/backend/ffi/MockBackend.java +++ b/hat/backends/ffi/mock/src/main/java/hat/backend/ffi/MockBackend.java @@ -24,7 +24,6 @@ */ package hat.backend.ffi; - import hat.ComputeContext; import hat.Config; import hat.KernelContext; @@ -36,13 +35,13 @@ public class MockBackend extends FFIBackend { public MockBackend(Arena arena, MethodHandles.Lookup lookup) { - super(arena,lookup,"mock_backend", Config.fromIntBits(0)); + super(arena, lookup, "mock_backend", Config.fromIntBits(0)); } @Override public void computeContextHandoff(ComputeContext computeContext) { System.out.println("Mock backend received closed closure"); - computeContext.computeCallGraph().callDag.entryPoint.funcOp((injectBufferTracking(config(),lookup(),computeContext.computeCallGraph().callDag.entryPoint.funcOp()))); + computeContext.computeCallGraph().callDag.entryPoint.funcOp((injectBufferTracking(config(), lookup(), computeContext.computeCallGraph().callDag.entryPoint.funcOp()))); } @Override diff --git a/hat/backends/ffi/mock/src/main/native/cpp/mock_backend.cpp b/hat/backends/ffi/mock/src/main/native/cpp/mock_backend.cpp index 4d34a5b8e9d..4ca80859bcb 100644 --- a/hat/backends/ffi/mock/src/main/native/cpp/mock_backend.cpp +++ b/hat/backends/ffi/mock/src/main/native/cpp/mock_backend.cpp @@ -135,6 +135,10 @@ class MockBackend final : public Backend { std::cout << "mock compute start()" << std::endl; } + std::string *getDeviceVendor() override { + return new std::string("Mock Vendor"); + } + CompilationUnit *compile(int len, char *source) override { std::cout << "mock compileProgram()" << std::endl; size_t srcLen = ::strlen(source); diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLBackend.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLBackend.java index 6b4fe0595a6..5db09922ec4 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLBackend.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLBackend.java @@ -24,7 +24,6 @@ */ package hat.backend.ffi; - import hat.ComputeContext; import hat.Config; import hat.KernelContext; @@ -36,26 +35,28 @@ public class OpenCLBackend extends C99FFIBackend { public OpenCLBackend(Config config) { - super(Arena.global(), MethodHandles.lookup(),"opencl_backend", config); + super(Arena.global(), MethodHandles.lookup(), "opencl_backend", config); } + public OpenCLBackend() { this(Config.fromEnvOrProperty()); } + @Override public void computeContextHandoff(ComputeContext computeContext) { - computeContext.computeCallGraph().callDag.entryPoint.funcOp(injectBufferTracking(config(),lookup(),computeContext.computeCallGraph().callDag.entryPoint.funcOp())); + computeContext.computeCallGraph().callDag.entryPoint.funcOp(injectBufferTracking(config(), lookup(), computeContext.computeCallGraph().callDag.entryPoint.funcOp())); } @Override public void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernelContext, Object... args) { - CompiledKernel compiledKernel = kernelCallGraphCompiledCodeMap.computeIfAbsent(kernelCallGraph, (_) -> { - String code = createC99(kernelCallGraph, args); + CompiledKernel compiledKernel = kernelCallGraphCompiledCodeMap.computeIfAbsent(kernelCallGraph, (KernelCallGraph _) -> { + String code = createC99(kernelCallGraph, args); if (config().showCode()) { - System.out.println(code); + IO.println(code); } var compilationUnit = backendBridge.compile(code); if (compilationUnit.ok()) { - var kernel = compilationUnit.getKernel( kernelCallGraph.callDag.entryPoint.method().getName()); + var kernel = compilationUnit.getKernel(kernelCallGraph.callDag.entryPoint.method().getName()); return new CompiledKernel(this, kernelCallGraph, kernel, args); } else { // TODO: We should capture the log from OpenCL and provide as exception message @@ -65,8 +66,9 @@ public void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernel compiledKernel.dispatch(kernelContext, args); } - String createC99(KernelCallGraph kernelCallGraph, Object[] args){ - return createCode(kernelCallGraph, new OpenCLHATKernelBuilder(kernelCallGraph, new ScopedCodeBuilderContext(kernelCallGraph.lookup(),kernelCallGraph.callDag.entryPoint.funcOp())), args); + String createC99(KernelCallGraph kernelCallGraph, Object[] args) { + kernelCallGraph.setDeviceVendor(backendBridge.getDeviceVendor()); + return createCode(kernelCallGraph, new OpenCLHATKernelBuilder(kernelCallGraph, new ScopedCodeBuilderContext(kernelCallGraph.lookup(), kernelCallGraph.callDag.entryPoint.funcOp())), args); } } diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java index 7e73f40bc79..9c075024900 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java @@ -32,6 +32,7 @@ import hat.types.BF16; import hat.types.F16; import optkl.OpHelper; +import jdk.incubator.code.Op; import optkl.codebuilders.CodeBuilder; import jdk.incubator.code.Value; import jdk.incubator.code.dialect.core.CoreOp; @@ -50,7 +51,7 @@ public class OpenCLHATKernelBuilder extends C99HATKernelBuilder { protected OpenCLHATKernelBuilder(KernelCallGraph kernelCallGraph, ScopedCodeBuilderContext scopedCodeBuilderContext) { - super(kernelCallGraph,scopedCodeBuilderContext); + super(kernelCallGraph, scopedCodeBuilderContext); } @Override @@ -222,6 +223,7 @@ public OpenCLHATKernelBuilder hatF16ToFloatConvOp( HATF16Op.HATF16ToFloatConvOp // Mapping between API function names and OpenCL intrinsics for the math operations private static final Map MATH_FUNCTIONS = new HashMap<>(); + static { MATH_FUNCTIONS.put("maxf", "max"); MATH_FUNCTIONS.put("maxd", "max"); diff --git a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp index 0ac60dfd5f7..b6d30090cd4 100644 --- a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp +++ b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp @@ -172,6 +172,11 @@ void OpenCLBackend::computeEnd() { } } +std::string* OpenCLBackend::getDeviceVendor() { + const PlatformInfo platformInfo(this); + return new std::string(platformInfo.vendorName); +} + OpenCLBackend::OpenCLProgram *OpenCLBackend::compileProgram(OpenCLSource &openclSource) { return compileProgram(&openclSource); } diff --git a/hat/backends/ffi/opencl/src/main/native/include/opencl_backend.h b/hat/backends/ffi/opencl/src/main/native/include/opencl_backend.h index 989f03744b8..c4e9c0b73a8 100644 --- a/hat/backends/ffi/opencl/src/main/native/include/opencl_backend.h +++ b/hat/backends/ffi/opencl/src/main/native/include/opencl_backend.h @@ -161,6 +161,8 @@ class OpenCLBackend final : public Backend { void computeEnd() override; + std::string *getDeviceVendor() override; + bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) override; void shortDeviceInfo() override; diff --git a/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/FFIBackendDriver.java b/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/FFIBackendDriver.java index 67d5d8f70e7..bd6a6ea8654 100644 --- a/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/FFIBackendDriver.java +++ b/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/FFIBackendDriver.java @@ -24,7 +24,6 @@ */ package hat.backend.ffi; - import hat.Config; import hat.backend.Backend; import hat.buffer.ArgArray; @@ -32,6 +31,7 @@ import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import java.lang.invoke.MethodHandles; import java.util.HashMap; import java.util.Map; @@ -55,6 +55,7 @@ public static class KernelBridge { final FFILib.VoidHandleMethodPtr releaseKernel_MPtr; String name; final FFILib.LongHandleLongAddressMethodPtr ndrange_MPtr; + KernelBridge(CompilationUnitBridge compilationUnitBridge, String name, long handle) { this.compilationUnitBridge = compilationUnitBridge; this.handle = handle; @@ -62,9 +63,11 @@ public static class KernelBridge { this.ndrange_MPtr = compilationUnitBridge.backendBridge.ffiLib.longHandleLongAddressFunc("ndrange"); this.name = name; } + public void ndRange(ArgArray argArray) { this.ndrange_MPtr.invoke(handle, MappableIface.getMemorySegment(argArray)); } + void release() { releaseKernel_MPtr.invoke(handle); } @@ -86,12 +89,15 @@ void release() { this.compilationUnitOK_MPtr = backendBridge.ffiLib.booleanHandleFunc("compilationUnitOK"); this.getKernel_MPtr = backendBridge.ffiLib.longHandleIntAddressFunc("getKernel"); } + void release() { this.releaseCompilationUnit_MPtr.invoke(handle); } + boolean ok() { return this.compilationUnitOK_MPtr.invoke(handle); } + public KernelBridge getKernel(String kernelName) { return kernels.computeIfAbsent(kernelName, _ -> new KernelBridge(this, kernelName, @@ -111,6 +117,9 @@ public KernelBridge getKernel(String kernelName) { final FFILib.VoidHandleMethodPtr showDeviceInfo_MPtr; final FFILib.BooleanHandleAddressLongMethodPtr getBufferFromDeviceIfDirty_MPtr; + final FFILib.StringFunctionMethodPtr getVendorFunction; + final FFILib.StringFunctionLengthMethodPtr stringFunctionLength; + BackendBridge(FFILib ffiLib, Config config) { this.ffiLib = ffiLib; this.getBackend_MPtr = ffiLib.longHandleIntFunc("getBackend"); @@ -122,10 +131,13 @@ public KernelBridge getKernel(String kernelName) { this.showDeviceInfo_MPtr = ffiLib.voidHandleFunc("showDeviceInfo"); this.computeStart_MPtr = ffiLib.voidHandleFunc("computeStart"); this.computeEnd_MPtr = ffiLib.voidHandleFunc("computeEnd"); + this.getVendorFunction = ffiLib.stringHandleFunc("getDeviceVendor"); + this.stringFunctionLength = ffiLib.stringFunctionLengthMethodPtr("getStringLength"); this.getBufferFromDeviceIfDirty_MPtr = ffiLib.booleanHandleAddressLongFunc("getBufferFromDeviceIfDirty"); } - void release() {} + void release() { + } public long getBackend(int configBits) { return getBackend_MPtr.invoke(configBits); @@ -142,20 +154,29 @@ public CompilationUnitBridge compile(String source) { return compilationUnit(compilationUnitHandle, source); } + public Vendor getDeviceVendor() { + MemorySegment vendorNameSegment = getVendorFunction.invoke(handle); + long sizeString = stringFunctionLength.invoke(vendorNameSegment); + byte[] content = vendorNameSegment.reinterpret(sizeString).toArray(ValueLayout.JAVA_BYTE); + return Vendor.of(new String(content)); + } + public MappableIface getBufferFromDeviceIfDirty(MappableIface buffer) { MemorySegment memorySegment = MappableIface.getMemorySegment(buffer); - if (!getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize())){ + if (!getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize())) { throw new IllegalStateException("Failed to get buffer from backend"); } return buffer; - } + public void computeStart() { computeStart_MPtr.invoke(handle); } + public void computeEnd() { computeEnd_MPtr.invoke(handle); } + public void showDeviceInfo() { showDeviceInfo_MPtr.invoke(handle); } @@ -164,8 +185,8 @@ public void showDeviceInfo() { public final FFILib ffiLib; public final BackendBridge backendBridge; - public FFIBackendDriver(Arena arena, MethodHandles.Lookup lookup,String libName, Config config) { - super(arena,lookup,config); + protected FFIBackendDriver(Arena arena, MethodHandles.Lookup lookup, String libName, Config config) { + super(arena, lookup, config); this.ffiLib = new FFILib(libName); this.backendBridge = new BackendBridge(ffiLib, config); } diff --git a/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/FFILib.java b/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/FFILib.java index b51941e4106..8de1fe17c63 100644 --- a/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/FFILib.java +++ b/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/FFILib.java @@ -24,6 +24,7 @@ */ package hat.backend.ffi; +import java.lang.foreign.AddressLayout; import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.Linker; import java.lang.foreign.MemorySegment; @@ -36,14 +37,14 @@ import static java.lang.foreign.ValueLayout.JAVA_LONG; public class FFILib { - final public String name; + public final String name; public final boolean available; - final public Linker nativeLinker; + public final Linker nativeLinker; - final public SymbolLookup loaderLookup; + public final SymbolLookup loaderLookup; - public static class MethodPtr{ + public static class MethodPtr { final FFILib ffiLib; final FunctionDescriptor functionDescriptor; final MethodHandle mh; @@ -51,7 +52,7 @@ public static class MethodPtr{ MethodPtr(FFILib ffiLib, FunctionDescriptor descriptor, String name) { this.ffiLib = ffiLib; - this.functionDescriptor= descriptor; + this.functionDescriptor = descriptor; this.mh = ffiLib.loaderLookup.find(name) .map(symbolSegment -> ffiLib.nativeLinker.downcallHandle(symbolSegment, descriptor)) .orElse(null); @@ -60,16 +61,57 @@ public static class MethodPtr{ } this.name = name; } + } + + public static class StringFunctionMethodPtr extends MethodPtr { + StringFunctionMethodPtr(FFILib ffiLib, String name) { + super(ffiLib, FunctionDescriptor.of(ADDRESS, JAVA_LONG), name); + } + + public MemorySegment invoke(long handler) { + if (mh == null) { + throw new NullPointerException("Null methodhandle " + name); + } + try { + MemorySegment segment = (MemorySegment) mh.invoke(handler); + if (segment.equals(MemorySegment.NULL)) { + throw new IllegalStateException("Function " + name + " returned NULL"); + } + return segment; + } catch (Throwable e) { + throw new RuntimeException(e); + } + } } - public static class VoidAddressMethodPtr extends MethodPtr{ + public static class StringFunctionLengthMethodPtr extends MethodPtr { + + StringFunctionLengthMethodPtr(FFILib ffiLib, String name) { + super(ffiLib, FunctionDescriptor.of(AddressLayout.JAVA_LONG, ADDRESS), name); + } + + public long invoke(MemorySegment memorySegment) { + if (mh == null) { + throw new RuntimeException("Null methodhandle " + name); + } + try { + return (long) mh.invoke(memorySegment); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + } + + public static class VoidAddressMethodPtr extends MethodPtr { VoidAddressMethodPtr(FFILib ffiLib, String name) { - super(ffiLib,FunctionDescriptor.ofVoid(ADDRESS), name); + super(ffiLib, FunctionDescriptor.ofVoid(ADDRESS), name); } + public void invoke(MemorySegment memorySegment) { - if (mh == null){ - throw new RuntimeException("Null methodhandle "+name); + if (mh == null) { + throw new RuntimeException("Null methodhandle " + name); } try { mh.invoke(memorySegment); @@ -79,13 +121,14 @@ public void invoke(MemorySegment memorySegment) { } } - public static class VoidHandleMethodPtr extends MethodPtr{ + public static class VoidHandleMethodPtr extends MethodPtr { VoidHandleMethodPtr(FFILib ffiLib, String name) { super(ffiLib, FunctionDescriptor.ofVoid(JAVA_LONG), name); } + public void invoke(long handle) { - if (mh == null){ - throw new RuntimeException("Null methodhandle "+name); + if (mh == null) { + throw new RuntimeException("Null methodhandle " + name); } if (handle == 0) { throw new RuntimeException("handle is zero"); @@ -98,89 +141,94 @@ public void invoke(long handle) { } } - public static class BooleanHandleMethodPtr extends MethodPtr{ + public static class BooleanHandleMethodPtr extends MethodPtr { BooleanHandleMethodPtr(FFILib ffiLib, String name) { - super(ffiLib, FunctionDescriptor.of(JAVA_BOOLEAN,JAVA_LONG),name); + super(ffiLib, FunctionDescriptor.of(JAVA_BOOLEAN, JAVA_LONG), name); } + public boolean invoke(long handle) { - if (mh == null){ - throw new RuntimeException("Null methodhandle "+name); + if (mh == null) { + throw new RuntimeException("Null methodhandle " + name); } if (handle == 0L) { throw new IllegalArgumentException("handle is zero"); } try { - return (boolean)mh.invoke(handle); + return (boolean) mh.invoke(handle); } catch (Throwable e) { throw new RuntimeException(e); } } } - public static class BooleanHandleAddressLongMethodPtr extends MethodPtr{ + public static class BooleanHandleAddressLongMethodPtr extends MethodPtr { BooleanHandleAddressLongMethodPtr(FFILib ffiLib, String name) { - super(ffiLib, FunctionDescriptor.of(JAVA_BOOLEAN,JAVA_LONG,ADDRESS,JAVA_LONG), name); + super(ffiLib, FunctionDescriptor.of(JAVA_BOOLEAN, JAVA_LONG, ADDRESS, JAVA_LONG), name); } - public boolean invoke(long handle,MemorySegment memorySegment, long len) { - if (mh == null){ - throw new RuntimeException("Null methodhandle "+name); + + public boolean invoke(long handle, MemorySegment memorySegment, long len) { + if (mh == null) { + throw new RuntimeException("Null methodhandle " + name); } if (handle == 0L) { throw new IllegalArgumentException("handle is zero"); } try { - return (boolean)mh.invoke(handle, memorySegment, len); + return (boolean) mh.invoke(handle, memorySegment, len); } catch (Throwable e) { throw new RuntimeException(e); } } } - public static class LongHandleIntAddressMethodPtr extends MethodPtr{ + public static class LongHandleIntAddressMethodPtr extends MethodPtr { LongHandleIntAddressMethodPtr(FFILib ffiLib, String name) { - super(ffiLib, FunctionDescriptor.of(JAVA_LONG,JAVA_LONG,JAVA_INT,ADDRESS), name); + super(ffiLib, FunctionDescriptor.of(JAVA_LONG, JAVA_LONG, JAVA_INT, ADDRESS), name); } + public long invoke(long handle, int i, MemorySegment memorySegment) { - if (mh == null){ - throw new RuntimeException("Null methodhandle "+name); + if (mh == null) { + throw new RuntimeException("Null methodhandle " + name); } if (handle == 0L) { throw new IllegalArgumentException("handle is zero"); } try { - return (long)mh.invoke(handle, i, memorySegment); + return (long) mh.invoke(handle, i, memorySegment); } catch (Throwable e) { throw new RuntimeException(e); } } } - public static class LongHandleIntMethodPtr extends MethodPtr{ + public static class LongHandleIntMethodPtr extends MethodPtr { LongHandleIntMethodPtr(FFILib ffiLib, String name) { - super(ffiLib,FunctionDescriptor.of(JAVA_LONG,JAVA_INT), name); + super(ffiLib, FunctionDescriptor.of(JAVA_LONG, JAVA_INT), name); } - public long invoke( int i) { - if (mh == null){ - throw new RuntimeException("Null method handle trying to invoke "+ffiLib.name+"::"+name+"()"); + + public long invoke(int i) { + if (mh == null) { + throw new RuntimeException("Null method handle trying to invoke " + ffiLib.name + "::" + name + "()"); } try { - return (long)mh.invoke(i); + return (long) mh.invoke(i); } catch (Throwable e) { throw new RuntimeException(e); } } } - public static class LongHandleLongAddressMethodPtr extends MethodPtr{ + public static class LongHandleLongAddressMethodPtr extends MethodPtr { LongHandleLongAddressMethodPtr(FFILib ffiLib, String name) { - super(ffiLib,FunctionDescriptor.of(JAVA_LONG,JAVA_LONG,ADDRESS), name); + super(ffiLib, FunctionDescriptor.of(JAVA_LONG, JAVA_LONG, ADDRESS), name); } - public long invoke(long l, MemorySegment memorySegment) { - if (mh == null){ - throw new RuntimeException("Null methodhandle "+name); + + public long invoke(long l, MemorySegment memorySegment) { + if (mh == null) { + throw new RuntimeException("Null methodhandle " + name); } try { - return (long)mh.invoke(l, memorySegment); + return (long) mh.invoke(l, memorySegment); } catch (Throwable e) { throw new RuntimeException(e); } @@ -201,7 +249,6 @@ public FFILib(String name) { this.loaderLookup = SymbolLookup.loaderLookup(); } - public VoidAddressMethodPtr voidAddressFunc(String name) { return new VoidAddressMethodPtr(this, name); } @@ -209,20 +256,33 @@ public VoidAddressMethodPtr voidAddressFunc(String name) { public VoidHandleMethodPtr voidHandleFunc(String name) { return new VoidHandleMethodPtr(this, name); } + public BooleanHandleMethodPtr booleanHandleFunc(String name) { return new BooleanHandleMethodPtr(this, name); } + public BooleanHandleAddressLongMethodPtr booleanHandleAddressLongFunc(String name) { return new BooleanHandleAddressLongMethodPtr(this, name); } + public LongHandleIntAddressMethodPtr longHandleIntAddressFunc(String name) { return new LongHandleIntAddressMethodPtr(this, name); } + public LongHandleIntMethodPtr longHandleIntFunc(String name) { return new LongHandleIntMethodPtr(this, name); } + public LongHandleLongAddressMethodPtr longHandleLongAddressFunc(String name) { return new LongHandleLongAddressMethodPtr(this, name); } + public StringFunctionMethodPtr stringHandleFunc(String name) { + return new StringFunctionMethodPtr(this, name); + } + + public StringFunctionLengthMethodPtr stringFunctionLengthMethodPtr(String name) { + return new StringFunctionLengthMethodPtr(this, name); + } + } diff --git a/hat/backends/ffi/shared/src/main/native/cpp/shared.cpp b/hat/backends/ffi/shared/src/main/native/cpp/shared.cpp index 3ac981365c0..9e421b78306 100644 --- a/hat/backends/ffi/shared/src/main/native/cpp/shared.cpp +++ b/hat/backends/ffi/shared/src/main/native/cpp/shared.cpp @@ -144,6 +144,19 @@ extern "C" long compile(long backendHandle, int len, char *source) { return compilationUnitHandle; } +extern "C" const char* getDeviceVendor(long backendHandle) { + auto *backend = reinterpret_cast(backendHandle); + const std::string* str = backend->getDeviceVendor(); + return str->data(); +} + +extern "C" long getStringLength(std::string *str) { + if (str == nullptr) { + return 0; + } + return str->length(); +} + extern "C" long getKernel(long compilationUnitHandle, int nameLen, char *name) { if (INFO) { std::cout << "trampolining through programHandle to compilationUnit.getKernel()" diff --git a/hat/backends/ffi/shared/src/main/native/include/shared.h b/hat/backends/ffi/shared/src/main/native/include/shared.h index 82021d8ba62..8db85b12fcb 100644 --- a/hat/backends/ffi/shared/src/main/native/include/shared.h +++ b/hat/backends/ffi/shared/src/main/native/include/shared.h @@ -577,6 +577,8 @@ class Backend { virtual bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) = 0; virtual ~Backend() = default; + + virtual std::string* getDeviceVendor() = 0; }; template diff --git a/hat/backends/java/mt/src/main/java/hat/backend/java/JavaMultiThreadedBackend.java b/hat/backends/java/mt/src/main/java/hat/backend/java/JavaMultiThreadedBackend.java index 58e2f5b3a7e..5e6d55ae7c0 100644 --- a/hat/backends/java/mt/src/main/java/hat/backend/java/JavaMultiThreadedBackend.java +++ b/hat/backends/java/mt/src/main/java/hat/backend/java/JavaMultiThreadedBackend.java @@ -35,7 +35,7 @@ public class JavaMultiThreadedBackend extends JavaBackend { @Override public void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernelContext, Object... args) { - // KernelEntrypoint kernelEntrypoint = kernelCallGraph.entrypoint; + // KernelEntrypoint kernelEntrypoint = kernelCallGraph.entrypoint; instance().forEachInRange(kernelContext, (kc) -> { Object[] a = Arrays.copyOf(args, args.length); // Annoying. we need to replace the args[0] but don't want to race other threads. try { @@ -57,5 +57,4 @@ synchronized WorkStealer instance() { return workStealer; } - } diff --git a/hat/backends/java/seq/src/main/java/hat/backend/java/JavaSequentialBackend.java b/hat/backends/java/seq/src/main/java/hat/backend/java/JavaSequentialBackend.java index feb664042c5..eba539947a4 100644 --- a/hat/backends/java/seq/src/main/java/hat/backend/java/JavaSequentialBackend.java +++ b/hat/backends/java/seq/src/main/java/hat/backend/java/JavaSequentialBackend.java @@ -33,7 +33,7 @@ public class JavaSequentialBackend extends JavaBackend { @Override public void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernelContext, Object... args) { - // KernelEntrypoint kernelEntrypoint = kernelCallGraph.entrypoint; + // KernelEntrypoint kernelEntrypoint = kernelCallGraph.entrypoint; for (kernelContext.gix = 0; kernelContext.gix < kernelContext.gsx; kernelContext.gix++) { try { args[0] = kernelContext; diff --git a/hat/core/src/main/java/hat/backend/Backend.java b/hat/core/src/main/java/hat/backend/Backend.java index d272a60bd7f..967e127d565 100644 --- a/hat/core/src/main/java/hat/backend/Backend.java +++ b/hat/core/src/main/java/hat/backend/Backend.java @@ -24,13 +24,11 @@ */ package hat.backend; - import hat.Accelerator; import hat.ComputeContext; import hat.Config; import hat.KernelContext; -//import hat.backend.java.JavaMultiThreadedBackend; -//import hat.backend.java.JavaSequentialBackend; +import hat.callgraph.KernelCallGraph; import jdk.incubator.code.Op; import jdk.incubator.code.Value; import jdk.incubator.code.dialect.core.CoreOp; @@ -39,11 +37,8 @@ import optkl.OpHelper; import optkl.Trxfmr; import optkl.ifacemapper.AccessType; -import hat.callgraph.KernelCallGraph; import optkl.ifacemapper.MappableIface; import optkl.util.carriers.ArenaAndLookupCarrier; -import optkl.util.carriers.ArenaCarrier; -import optkl.util.carriers.LookupCarrier; import java.lang.foreign.Arena; import java.lang.invoke.MethodHandles; @@ -55,25 +50,33 @@ import static hat.ComputeContext.WRAPPER.MUTATE; import static optkl.OpHelper.Invoke.invoke; -public abstract class Backend implements ArenaAndLookupCarrier { +//import hat.backend.java.JavaMultiThreadedBackend; +//import hat.backend.java.JavaSequentialBackend; + +public abstract class Backend implements ArenaAndLookupCarrier { private final Config config; - public Config config(){ + public Config config() { return config; } private final Arena arena; - @Override public Arena arena(){ + + @Override + public Arena arena() { return arena; } + private final MethodHandles.Lookup lookup; - @Override public MethodHandles.Lookup lookup(){ + + @Override + public MethodHandles.Lookup lookup() { return lookup; } - protected Backend(Arena arena, MethodHandles.Lookup lookup,Config config){ + protected Backend(Arena arena, MethodHandles.Lookup lookup, Config config) { this.lookup = lookup; - this.arena =arena; + this.arena = arena; this.config = config; } @@ -105,9 +108,8 @@ public static Backend getBackend(Predicate backendPredicate) { public abstract void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernelContext, Object... args); - - public static CoreOp.FuncOp injectBufferTracking(Config config, MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { - var transformer = Trxfmr.of(lookup,funcOp); + public static CoreOp.FuncOp injectBufferTracking(Config config, MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { + var transformer = Trxfmr.of(lookup, funcOp); if (config.minimizeCopies()) { var paramTable = new FuncOpParams(funcOp); return transformer @@ -155,5 +157,4 @@ public static CoreOp.FuncOp injectBufferTracking(Config config, MethodHandles.L } } - } diff --git a/hat/core/src/main/java/hat/backend/ffi/Vendor.java b/hat/core/src/main/java/hat/backend/ffi/Vendor.java new file mode 100644 index 00000000000..92420014d96 --- /dev/null +++ b/hat/core/src/main/java/hat/backend/ffi/Vendor.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package hat.backend.ffi; + +import java.util.Locale; + +public interface Vendor { + + String name(); + + record Apple(String name) implements Vendor { + } + + record NVIDIA(String name) implements Vendor { + } + + record Intel(String name) implements Vendor { + } + + record AMD(String name) implements Vendor { + + } + + static Vendor of(String name) { + String canonicalName = name.toLowerCase(Locale.ENGLISH); + if (canonicalName.startsWith("apple")) { + return new Apple(name); + } else if (canonicalName.startsWith("nvidia")) { + return new NVIDIA(name); + } else if (canonicalName.startsWith("amd")) { + return new AMD(name); + } else if (canonicalName.startsWith("intel")) { + return new Intel(name); + } else { + throw new RuntimeException("Backend Vendor: " + name + " not implemented yet."); + } + } +} diff --git a/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java b/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java index 006fce8d904..986a3519be8 100644 --- a/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java +++ b/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java @@ -26,13 +26,14 @@ import hat.BufferTagger; import hat.KernelContext; +import hat.backend.ffi.Vendor; import hat.device.NonMappableIface; import hat.phases.HATTier; import hat.types.S16ImplOfF16; import hat.types.Tensor; import jdk.incubator.code.CodeTransformer; -import jdk.incubator.code.Op; import jdk.incubator.code.CodeType; +import jdk.incubator.code.Op; import jdk.incubator.code.dialect.core.CoreOp; import jdk.incubator.code.dialect.core.SSA; import jdk.incubator.code.dialect.java.ClassType; @@ -55,12 +56,17 @@ import static optkl.OpHelper.Invoke.invoke; public class KernelCallGraph implements LookupCarrier { - @Override public MethodHandles.Lookup lookup(){ + + private Vendor vendor; + + @Override + public MethodHandles.Lookup lookup() { return computeCallGraph.lookup(); } - public static final boolean showKernelCallDag = Boolean.getBoolean("showKernelCallDag"); - public static final boolean showKernelIfaceDag = Boolean.getBoolean("showKernelIfaceDag"); - public static final boolean showKernelIfaceDagProposedTypedefs = Boolean.getBoolean("showKernelIfaceDagProposedTypedefs"); + + public static final boolean showKernelCallDag = Boolean.getBoolean("showKernelCallDag"); + public static final boolean showKernelIfaceDag = Boolean.getBoolean("showKernelIfaceDag"); + public static final boolean showKernelIfaceDagProposedTypedefs = Boolean.getBoolean("showKernelIfaceDagProposedTypedefs"); public final ComputeCallGraph computeCallGraph; public final MethodCallDag callDag; @@ -82,18 +88,18 @@ public class KernelCallGraph implements LookupCarrier { this.computeCallGraph = computeCallGraph; - CoreOp.FuncOp ssaFunc = SSA.transform( e.transform(CodeTransformer.LOWERING_TRANSFORMER)) ; - var changed = Mutable.of(true); + CoreOp.FuncOp ssaFunc = SSA.transform(e.transform(CodeTransformer.LOWERING_TRANSFORMER)); + var changed = Mutable.of(true); while (changed.get()) { // loop until no more inline-able functions changed.set(false); - ssaFunc = ssaFunc.transform( (blockbuilder, op) -> { + ssaFunc = ssaFunc.transform((blockbuilder, op) -> { if (invoke(lookup(), op) instanceof OpHelper.Invoke invoke // always but pattern friendly && invoke.resolvedMethodOrNull() instanceof Method m && Op.ofMethod(m) instanceof Optional optionalFuncOp // always but pattern friendly && optionalFuncOp.isPresent() && optionalFuncOp.get() instanceof CoreOp.FuncOp inline // always we just want var in scope - ){ - var ssaInline =SSA.transform(inline.transform(CodeTransformer.LOWERING_TRANSFORMER)); + ) { + var ssaInline = SSA.transform(inline.transform(CodeTransformer.LOWERING_TRANSFORMER)); var exitBlockBuilder = jdk.incubator.code.dialect.core.Inliner.inline( blockbuilder, ssaInline, blockbuilder.context().getValues(invoke.op().operands()), (_, _value) -> { @@ -125,20 +131,20 @@ public class KernelCallGraph implements LookupCarrier { this.accessedClasses = this.accessedTypes.stream() .filter(te -> te instanceof ClassType).map(te -> (Class) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) te)) .collect(Collectors.toSet()); - this.accessedIfaceClasses = this.accessedClasses.stream() - .filter(c->IfaceValue.class.isAssignableFrom(c)).map(c->(Class)c) + this.accessedIfaceClasses = this.accessedClasses.stream() + .filter(c -> IfaceValue.class.isAssignableFrom(c)).map(c -> (Class) c) .collect(Collectors.toSet()); - this.accessedMappableIfaceClasses = this.accessedIfaceClasses.stream() - .filter(c->MappableIface.class.isAssignableFrom(c)).map(c->(Class)c) + this.accessedMappableIfaceClasses = this.accessedIfaceClasses.stream() + .filter(c -> MappableIface.class.isAssignableFrom(c)).map(c -> (Class) c) .collect(Collectors.toSet()); - this.accessedNonMappableIfaceClasses = this.accessedIfaceClasses.stream() - .filter(c->NonMappableIface.class.isAssignableFrom(c)).map(c->(Class)c) + this.accessedNonMappableIfaceClasses = this.accessedIfaceClasses.stream() + .filter(c -> NonMappableIface.class.isAssignableFrom(c)).map(c -> (Class) c) .collect(Collectors.toSet()); - this.accessedVecClasses = this.accessedClasses.stream() - .filter(c->IfaceValue.vec.class.isAssignableFrom(c)).map(c->(Class)c) + this.accessedVecClasses = this.accessedClasses.stream() + .filter(c -> IfaceValue.vec.class.isAssignableFrom(c)).map(c -> (Class) c) .collect(Collectors.toSet()); - this.accessedFP16Classes = this.accessedClasses.stream() - .filter(c-> S16ImplOfF16.class.isAssignableFrom(c)).map(c->(Class)c) + this.accessedFP16Classes = this.accessedClasses.stream() + .filter(c -> S16ImplOfF16.class.isAssignableFrom(c)).map(c -> (Class) c) .collect(Collectors.toSet()); this.usesAtomics = OpHelper.Invoke.stream(lookup(), inlinedEntryPoint) .anyMatch(invoke -> @@ -147,9 +153,6 @@ public class KernelCallGraph implements LookupCarrier { && invoke.returnsInt() && invoke.nameMatchesRegex("(atomic.*)Inc")); - - - this.bufferAccessList = BufferTagger.getAccessList(lookup(), inlinedEntryPoint); var entrypoint = new FuncOpCarrier.Impl(e); @@ -163,14 +166,14 @@ public class KernelCallGraph implements LookupCarrier { this.callDag.view("kernelCallDag", n -> n.funcOp().funcName()); } - this.ifaceDag = new IfaceDataDag<>(dag-> - entrypoint.funcOp().elements() - .filter(ce -> ce instanceof Op).map(ce -> ((Op) ce).resultType()) - .filter(codeType -> codeType instanceof ClassType).map(codeType -> dag.getNode(lookup(), (ClassType) codeType)) - .filter(impl -> IfaceValue.class.isAssignableFrom(impl.clazz())) - .forEach(iface -> dag.methodsWithIfaceReturnTypes(iface.clazz()) - .forEach(retType -> dag.addEdge(iface, retType)) - ) + this.ifaceDag = new IfaceDataDag<>(dag -> + entrypoint.funcOp().elements() + .filter(ce -> ce instanceof Op).map(ce -> ((Op) ce).resultType()) + .filter(codeType -> codeType instanceof ClassType).map(codeType -> dag.getNode(lookup(), (ClassType) codeType)) + .filter(impl -> IfaceValue.class.isAssignableFrom(impl.clazz())) + .forEach(iface -> dag.methodsWithIfaceReturnTypes(iface.clazz()) + .forEach(retType -> dag.addEdge(iface, retType)) + ) ); if (showKernelIfaceDag) { this.ifaceDag.view("kernelDataDag", IfaceDataDag.IfaceInfo::dotName); @@ -179,4 +182,12 @@ public class KernelCallGraph implements LookupCarrier { ifaceDag.rankOrdered.forEach(ifaceInfo -> System.out.println("create typedef " + ifaceInfo.classType())); } } + + public void setDeviceVendor(Vendor vendor) { + this.vendor = vendor; + } + + public Vendor getDeviceVendor() { + return this.vendor; + } }