Skip to content

Commit

Permalink
feat: allow type-casting numbers to tracednumbers (#209)
Browse files Browse the repository at this point in the history
* feat: allow type-casting numbers to tracednumbers

* chore: apply formatting suggestion
  • Loading branch information
avik-pal authored Nov 1, 2024
1 parent f570fcc commit babeb7c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
return promote_to(TracedRNumber{T}, x)
end

# Constructor-style casts: `TracedRNumber{T}(x)`.

# A value that already has exactly the requested type is returned unchanged.
function TracedRNumber{T}(x::TracedRNumber{T}) where {T}
    return x
end

# Any other `Number` is cast by delegating to `promote_to` with the requested type.
TracedRNumber{T}(x::Number) where {T} = promote_to(TracedRNumber{T}, x)

function promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
if isa(rhs, TracedRNumber)
rhs isa TracedRNumber{T} && return rhs
Expand Down
10 changes: 10 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,13 @@ end
@test res3 isa ConcreteRArray
end
end

# Scalar ReLU: the larger of zero and `x`.
# The zero is built with the constructor call `T(0)` on purpose, so that the
# "type casting" testset below goes through `T(::Number)` for whatever number
# type `T` the function is called with.
function relu(x::T) where {T<:Number}
    return max(T(0), x)
end

# Fallback for inputs that are not a `Number`: broadcast `relu` over `x`.
function relu(x)
    return relu.(x)
end

# Checks that a function which builds a constant via `T(0)` (see `relu`) gives the
# same result when run through `@jit` on a Reactant array as it does on the plain
# Julia array it was created from.
@testset "type casting" begin
    x = randn(2, 10)
    x_ra = Reactant.to_rarray(x)

    # Compare with `≈` (isapprox): the results are floating point. Without an
    # operator between the two expressions, `@test` is not a valid call and
    # nothing is compared.
    @test @jit(relu(x_ra)) ≈ relu(x)
end

1 comment on commit babeb7c

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: babeb7c Previous: f570fcc Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1418797944 ns 1340567545 ns 1.06
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1230657063 ns 1354795677 ns 0.91
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1210055514 ns 1296358153 ns 0.93
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2321453182 ns 2617478292 ns 0.89
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 215031968 ns 207121854 ns 1.04
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 5458708327 ns 5245986343 ns 1.04
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5179301625 ns 5473784946 ns 0.95
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 5152065959 ns 5562801011 ns 0.93
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 6914653384 ns 6785865699 ns 1.02
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 29634509034 ns 28788392011 ns 1.03
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1303933391 ns 1329681507 ns 0.98
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1288941570.5 ns 1310470804 ns 0.98
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1246884488 ns 1322248869 ns 0.94
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2588209027 ns 2593706297 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 8825930 ns 8538279.5 ns 1.03
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1637260762 ns 1569392248 ns 1.04
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1607338067 ns 1563566948 ns 1.03
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1592753746 ns 1601308923.5 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 2888392716 ns 2743850639 ns 1.05
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 2959415354 ns 2498208075 ns 1.18
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1320589513 ns 1314553261 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1232647002.5 ns 1520974050.5 ns 0.81
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1233197730.5 ns 1289846630 ns 0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2510219663 ns 2616640648 ns 0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 22686905 ns 21421171 ns 1.06
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2195365921 ns 2256971186 ns 0.97
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2173148463 ns 2259019481 ns 0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2160517237 ns 2253385981 ns 0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3389252115 ns 3571597383 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 5458754250.5 ns 6399412347.5 ns 0.85
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1336344147 ns 1315428709.5 ns 1.02
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1284165465.5 ns 1285752949.5 ns 1.00
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1264413606 ns 1352009040.5 ns 0.94
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2388755659 ns 2483558090 ns 0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 7116389 ns 7375914.5 ns 0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1494584041 ns 1471123545 ns 1.02
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1490742502 ns 1460958972 ns 1.02
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1473980569 ns 1467352849 ns 1.00
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 2796807816 ns 2771232207 ns 1.01
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1669183460 ns 1067840903.5 ns 1.56
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1220367359.5 ns 1265463942 ns 0.96
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1264274640.5 ns 1336662627 ns 0.95
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1345724410.5 ns 1343776367 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2566724316 ns 2625246543 ns 0.98
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 12278807 ns 15448376 ns 0.79
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1777269725 ns 1762453178 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1763977334 ns 1744480870 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1773537556 ns 1747667945 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3105794746 ns 3050720896 ns 1.02
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3076042064.5 ns 2931138896 ns 1.05
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1271346742 ns 1334489477 ns 0.95
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1246562750 ns 1354072030 ns 0.92
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1309043330 ns 1326138010.5 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2442642621 ns 2585240240 ns 0.94
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 27314834 ns 25592502.5 ns 1.07
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2242544865 ns 2380176040 ns 0.94
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2216501128 ns 2263350697 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2196805969 ns 2217766749 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3556647163 ns 3470178449 ns 1.02
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 5559034960 ns 7784948885.5 ns 0.71
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1242324757 ns 1267787779 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1298352031 ns 1334201675 ns 0.97
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1230035861 ns 1233375595 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2637986128 ns 2426265268 ns 1.09
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 52652664 ns 50717982 ns 1.04
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 3060315768 ns 3180815831 ns 0.96
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3106069884 ns 3006942334 ns 1.03
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 3053865991 ns 3048786017 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4567618226 ns 4420144603 ns 1.03
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 9483960261 ns 8266182250 ns 1.15
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1231844420 ns 1289431106 ns 0.96
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1232961844 ns 1307821758 ns 0.94
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1246778659.5 ns 1319155148 ns 0.95
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2387803469 ns 2570243985 ns 0.93
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 70768943 ns 68121207.5 ns 1.04
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3264057828 ns 3184641683 ns 1.02
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3289278023 ns 3196688266 ns 1.03
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3264052473 ns 3219193931 ns 1.01
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 4733831239 ns 4600542400 ns 1.03
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 10856363466 ns 14366625373 ns 0.76
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1204313355 ns 1248144802 ns 0.96
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1195727515.5 ns 1280560424 ns 0.93
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1220121107.5 ns 1260333857 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2407152921 ns 2541156555 ns 0.95
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 20638923 ns 19634575 ns 1.05
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1946630937 ns 1915886353 ns 1.02
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1945684183 ns 1903714614 ns 1.02
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1942898142 ns 1892177049 ns 1.03
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3272416211 ns 3107063618 ns 1.05
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 3630039016.5 ns 3075524538.5 ns 1.18

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.