Skip to content

Commit

Permalink
Add cse-ed ops to opdict, unroll some short static vectors.
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Apr 25, 2020
1 parent 133bd96 commit ea2d6f0
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 8 deletions.
5 changes: 4 additions & 1 deletion src/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ end
operations(ls::LoopSet) = ls.operations
function pushop!(ls::LoopSet, op::Operation, var::Symbol = name(op))
for opp ∈ operations(ls)
matches(op, opp) && return opp
if matches(op, opp)
ls.opdict[var] = opp
return opp
end
end
push!(ls.operations, op)
ls.opdict[var] = op
Expand Down
11 changes: 10 additions & 1 deletion src/lower_compute.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# A compute op needs to know the unrolling and tiling status of each of its parents.

# Decide whether a reduction FMA on `op` may be rewritten to its `231` form.
#
# `unrolled` and `tiled` are the loop symbols of the unrolled and tiled loops.
# Those of the two that appear in `loopdependencies(op)` are collected, and the
# function returns `true` only when no parent of `op` is a load whose loop
# dependencies contain every collected symbol.
#
# Note: when neither symbol is a dependency of `op`, the collected list is empty
# and `all(...)` is vacuously true, so any load parent makes this return `false`.
function promote_to_231(op, unrolled, tiled)
    opdeps = loopdependencies(op)
    unrolleddeps = Symbol[]
    unrolled ∈ opdeps && push!(unrolleddeps, unrolled)
    tiled ∈ opdeps && push!(unrolleddeps, tiled)
    # A parent load "covers" the unrolled loops when each collected symbol is
    # among that load's own loop dependencies.
    covers(opp) = isload(opp) && all(in(loopdependencies(opp)), unrolleddeps)
    return !any(covers, parents(op))
end

# Field-less placeholder collection: indexing it with any number of indices,
# of any type, always yields `false`.
struct FalseCollection end
function Base.getindex(::FalseCollection, inds...)
    return false
end
function lower_compute!(
Expand Down Expand Up @@ -76,7 +83,9 @@ function lower_compute!(
if !isnothing(suffix) && isreduct
# instrfid = findfirst(isequal(instr.instr), (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub))
instrfid = findfirst(isequal(instr.instr), (:vfmadd_fast, :vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
if instrfid !== nothing && !any(opp -> isload(opp) && all(in(loopdependencies(opp)), loopdependencies(op)), parents(op)) # want to instcombine when parent load's deps are superset
# want to instcombine when parent load's deps are superset
# also make sure opp is unrolled
if instrfid !== nothing && (opunrolled && U > 1) && promote_to_231(op, unrolled, tiled)
instr = Instruction((:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)[instrfid])
end
end
Expand Down
36 changes: 30 additions & 6 deletions src/lowering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,27 +172,49 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
nisvectorized = isvectorized(us, n)
loopisstatic = isstaticloop(loop) & (!nisvectorized)


remmask = inclmask | nisvectorized
Ureduct = (n == num_loops(ls) && (u₂ == -1)) ? calc_Ureduct(ls, us) : -1
sl = startloop(loop, nisvectorized, ls.W, loopsym)
tc = terminatecondition(loop, us, n, ls.W, loopsym, inclmask, UF)
body = lower_block(ls, us, n, inclmask, UF)

remfirst = loopisstatic & !(unsigned(Ureduct) < unsigned(UF))
if remfirst
tc = Expr(:call, lv(:scalar_less), loopsym, loop.stophint + 1)
else
tc = terminatecondition(loop, us, n, ls.W, loopsym, inclmask, UF)
end
body = lower_block(ls, us, n, inclmask, UF)
q = Expr(:while, tc, body)
remblock = init_remblock(loop, loopsym)
UFt = if loopisstatic
length(loop) % UF
else
1
end
q = if unsigned(Ureduct) < unsigned(UF) # unsigned(-1) == typemax(UInt); is logic relying on twos-complement bad?
Expr(
:block, sl,
add_upper_outer_reductions(ls, q, Ureduct, UF, loop, vectorized),
Expr(
:if, terminatecondition(loop, us, n, ls.W, loopsym, inclmask, UF - Ureduct),
lower_block(ls, us, n, inclmask, UF - Ureduct)
)
),
remblock
)
elseif remfirst
numiters = length(loop) ÷ UF
if numiters > 2
Expr( :block, sl, remblock, q )
else
q = Expr(:block, sl, remblock)
for i ∈ 1:numiters
push!(q.args, body)
end
q
end
else
Expr( :block, sl, q )
Expr( :block, sl, q, remblock )
end
remblock = init_remblock(loop, loopsym)
push!(q.args, remblock)
UFt = if loopisstatic
length(loop) % UF
else
Expand All @@ -206,6 +228,8 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
Expr(:call, :-, loop.stopsym, Expr(:call, lv(:valmul), ls.W, UFt))
end
Expr(:call, lv(:scalar_greater), loopsym, itercount)
elseif remfirst
Expr(:call, lv(:scalar_less), loopsym, loop.starthint + UFt)
elseif loop.stopexact
Expr(:call, lv(:scalar_greater), loopsym, loop.stophint - UFt)
else
Expand Down
24 changes: 24 additions & 0 deletions test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,26 @@
end; y
end

# Plain-loop reference implementation.
# For every index pair of `x` (columns in the outer loop, rows in the inner loop)
# writes `y[i, j] = z[i, j] * (-(1 - tanh(x[i, j])^2) * inv(tanh(x[i, j])))`
# and returns the mutated `y`.
function csetanh!(y, z, x)
    for col in axes(x, 2), row in axes(x, 1)
        # `tanh` is evaluated twice on the same element, once per temporary.
        recip = inv(tanh(x[row, col]))
        th = tanh(x[row, col])
        y[row, col] = z[row, col] * (-(1 - th ^ 2) * recip)
    end
    return y
end
# `@avx` counterpart of `csetanh!`: the same loop nest and arithmetic, with the
# outer loop wrapped in the `@avx` macro. Fills `y` in place and returns it.
# NOTE(review): `tanh(x[i, j])` is written twice below; presumably deliberate so
# the macro sees a repeated subexpression (the commit message mentions "cse-ed
# ops") — confirm before simplifying the body.
function csetanhavx!(y, z, x)
@avx for j in axes(x, 2)
for i = axes(x, 1)
# t2 is the reciprocal of tanh(x[i, j]); t1 recomputes the same tanh.
t2 = inv(tanh(x[i, j]))
t1 = tanh(x[i, j])
y[i, j] = z[i, j] * (-(1 - t1 ^ 2) * t2)
end
end
y
end

for T ∈ (Float32, Float64)
@show T, @__LINE__
Expand Down Expand Up @@ -314,5 +334,9 @@
@test vpowf!(r1, x, -1.7) ≈ (r2 .= x .^ -1.7)
p = randn(length(x));
@test vpowf!(r1, x, x) ≈ (r2 .= x .^ x)

X = rand(T, N, M); Z = rand(T, N, M);
Y1 = similar(X); Y2 = similar(Y1);
@test csetanh!(Y1, X, Z) ≈ csetanhavx!(Y2, X, Z)
end
end

2 comments on commit ea2d6f0

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/13649

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.2 -m "<description of version>" ea2d6f0d6bb7c18a14c79b1b3c456373020d5886
git push origin v0.7.2

Please sign in to comment.