Description
Generating HLO code with `@code_hlo` and then compiling that code through `Ops.hlo_call` fails with an array shape mismatch.
I see a similar issue with HLO code generated by JAX.
I believe this worked around April, but I can't find the exact Reactant version I was using then.
Versions: Julia 1.10.9, Reactant 0.2.143 (Google Colab)
using Reactant: Ops, to_rarray, @compile, @code_hlo
using Random
a = to_rarray(randn(Float32, 2,3))
b = copy(a)
@compile(
    Ops.hlo_call(
        """
        module {
          func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> {
            %0 = stablehlo.add %arg0, %arg1 : tensor<2x3xf32>
            return %0 : tensor<2x3xf32>
          }
        }
        """, a, b)
) # works
f(a,b) = a.*b
@compile f(a,b) # works
code = string(@code_hlo f(a,b))
println(code)
f2(a,b) = Ops.hlo_call(code, a, b)
@compile f2(a,b) # fails
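To narrow down where the shape mismatch comes from, it can help to compare the argument shapes declared in the `@main` signature of `code` with `size(a)`. The sketch below is a hypothetical diagnostic helper (`main_arg_shapes` is not part of Reactant) and assumes the `@main` signature sits on a single line, as in the hand-written module above; it only parses text, so it shows whether the dimensions appear in Julia's order or reversed, not why `hlo_call` rejects them.

```julia
# Hypothetical diagnostic helper (not part of Reactant): return the tensor
# shapes declared for the arguments of `func.func @main` in an MLIR text module.
function main_arg_shapes(mlir::AbstractString)
    # Capture the argument list of the @main signature (assumed to be on one line).
    m = match(r"func\.func\s+@main\((.*?)\)\s*->", mlir)
    m === nothing && error("no `func.func @main(...) ->` signature found")
    shapes = Vector{Vector{Int}}()
    # Each ranked tensor type looks like tensor<2x3xf32>: dims, then element type.
    for t in eachmatch(r"tensor<([0-9x]+)x[a-z0-9]+>", m.captures[1])
        push!(shapes, parse.(Int, split(t.captures[1], 'x')))
    end
    return shapes
end

# Example with the signature of the hand-written module from the report:
ir = "func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> {"
julia_dims = [2, 3]                         # collect(size(a)) for the arrays above
shapes = main_arg_shapes(ir)
println(shapes)                             # [[2, 3], [2, 3]]
println(shapes[1] == julia_dims)            # same order as size(a)?
println(shapes[1] == reverse(julia_dims))   # reversed order?
```

Calling `main_arg_shapes(code)` on the string printed by `@code_hlo f(a,b)` and comparing the result with `collect(size(a))` and its reverse would show which layout the generated signature uses.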