From ea2d6f0d6bb7c18a14c79b1b3c456373020d5886 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Sat, 25 Apr 2020 15:06:34 -0400 Subject: [PATCH] Add cse-ed ops to opdict, unroll some short static vectors. --- src/graphs.jl | 5 ++++- src/lower_compute.jl | 11 ++++++++++- src/lowering.jl | 36 ++++++++++++++++++++++++++++++------ test/special.jl | 24 ++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 8 deletions(-) diff --git a/src/graphs.jl b/src/graphs.jl index ec15e963b..42dd78d43 100644 --- a/src/graphs.jl +++ b/src/graphs.jl @@ -322,7 +322,10 @@ end operations(ls::LoopSet) = ls.operations function pushop!(ls::LoopSet, op::Operation, var::Symbol = name(op)) for opp ∈ operations(ls) - matches(op, opp) && return opp + if matches(op, opp) + ls.opdict[var] = opp + return opp + end end push!(ls.operations, op) ls.opdict[var] = op diff --git a/src/lower_compute.jl b/src/lower_compute.jl index 5f57f01ff..98a07f3ea 100644 --- a/src/lower_compute.jl +++ b/src/lower_compute.jl @@ -1,5 +1,12 @@ # A compute op needs to know the unrolling and tiling status of each of its parents. +function promote_to_231(op, unrolled, tiled) + unrolleddeps = Symbol[] + unrolled ∈ loopdependencies(op) && push!(unrolleddeps, unrolled) + tiled ∈ loopdependencies(op) && push!(unrolleddeps, tiled) + !any(opp -> isload(opp) && all(in(loopdependencies(opp)), unrolleddeps), parents(op)) +end + struct FalseCollection end Base.getindex(::FalseCollection, i...) = false function lower_compute!( @@ -76,7 +83,9 @@ function lower_compute!( if !isnothing(suffix) && isreduct # instrfid = findfirst(isequal(instr.instr), (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub)) instrfid = findfirst(isequal(instr.instr), (:vfmadd_fast, :vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast)) - if instrfid !== nothing && !any(opp -> isload(opp) && all(in(loopdependencies(opp)), loopdependencies(op)), parents(op)) # want to instcombine when parent load's deps are superset + # want to instcombine when parent load's deps are superset + # also make sure opp is unrolled + if instrfid !== nothing && (opunrolled && U > 1) && promote_to_231(op, unrolled, tiled) instr = Instruction((:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)[instrfid]) end end diff --git a/src/lowering.jl b/src/lowering.jl index 917a95d25..28cfbfe72 100644 --- a/src/lowering.jl +++ b/src/lowering.jl @@ -172,13 +172,25 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in nisvectorized = isvectorized(us, n) loopisstatic = isstaticloop(loop) & (!nisvectorized) + remmask = inclmask | nisvectorized Ureduct = (n == num_loops(ls) && (u₂ == -1)) ? calc_Ureduct(ls, us) : -1 sl = startloop(loop, nisvectorized, ls.W, loopsym) - tc = terminatecondition(loop, us, n, ls.W, loopsym, inclmask, UF) - body = lower_block(ls, us, n, inclmask, UF) + remfirst = loopisstatic & !(unsigned(Ureduct) < unsigned(UF)) + if remfirst + tc = Expr(:call, lv(:scalar_less), loopsym, loop.stophint + 1) + else + tc = terminatecondition(loop, us, n, ls.W, loopsym, inclmask, UF) + end + body = lower_block(ls, us, n, inclmask, UF) q = Expr(:while, tc, body) + remblock = init_remblock(loop, loopsym) + UFt = if loopisstatic + length(loop) % UF + else + 1 + end q = if unsigned(Ureduct) < unsigned(UF) # unsigned(-1) == typemax(UInt); is logic relying on twos-complement bad? Expr( :block, sl, @@ -186,13 +198,23 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in Expr( :if, terminatecondition(loop, us, n, ls.W, loopsym, inclmask, UF - Ureduct), lower_block(ls, us, n, inclmask, UF - Ureduct) - ) + ), + remblock ) + elseif remfirst + numiters = length(loop) ÷ UF + if numiters > 2 + Expr( :block, sl, remblock, q ) + else + q = Expr(:block, sl, remblock) + for i ∈ 1:numiters + push!(q.args, body) + end + q + end else - Expr( :block, sl, q ) + Expr( :block, sl, q, remblock ) end - remblock = init_remblock(loop, loopsym) - push!(q.args, remblock) UFt = if loopisstatic length(loop) % UF else @@ -206,6 +228,8 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in Expr(:call, :-, loop.stopsym, Expr(:call, lv(:valmul), ls.W, UFt)) end Expr(:call, lv(:scalar_greater), loopsym, itercount) + elseif remfirst + Expr(:call, lv(:scalar_less), loopsym, loop.starthint + UFt) elseif loop.stopexact Expr(:call, lv(:scalar_greater), loopsym, loop.stophint - UFt) else diff --git a/test/special.jl b/test/special.jl index 28d1a03b6..e40591100 100644 --- a/test/special.jl +++ b/test/special.jl @@ -251,6 +251,26 @@ end; y end + function csetanh!(y, z, x) + for j in axes(x, 2) + for i = axes(x, 1) + t2 = inv(tanh(x[i, j])) + t1 = tanh(x[i, j]) + y[i, j] = z[i, j] * (-(1 - t1 ^ 2) * t2) + end + end + y + end + function csetanhavx!(y, z, x) + @avx for j in axes(x, 2) + for i = axes(x, 1) + t2 = inv(tanh(x[i, j])) + t1 = tanh(x[i, j]) + y[i, j] = z[i, j] * (-(1 - t1 ^ 2) * t2) + end + end + y + end for T ∈ (Float32, Float64) @show T, @__LINE__ @@ -314,5 +334,9 @@ @test vpowf!(r1, x, -1.7) ≈ (r2 .= x .^ -1.7) p = randn(length(x)); @test vpowf!(r1, x, x) ≈ (r2 .= x .^ x) + + X = rand(T, N, M); Z = rand(T, N, M); + Y1 = similar(X); Y2 = similar(Y1); + @test csetanh!(Y1, X, Z) ≈ csetanhavx!(Y2, X, Z) end end