Skip to content

Commit

Permalink
Get tests to pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Mar 30, 2020
1 parent 9bfa9a5 commit 8771d75
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
8 changes: 6 additions & 2 deletions src/determinestrategy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,6 @@ function choose_unroll_order(ls::LoopSet, lowest_cost::Float64 = Inf)
end
function choose_tile(ls::LoopSet)
lo = LoopOrders(ls)
# @show lo.syms ls.loop_order.bestorder
best_order = copyto!(ls.loop_order.bestorder, lo.syms)
best_unrolled = best_tiled = best_vec = first(best_order) # filler
new_order, state = iterate(lo) # right now, new_order === best_order
Expand Down Expand Up @@ -651,7 +650,8 @@ function choose_tile(ls::LoopSet)
end
end
# Last in order is the inner most loop
function choose_order(ls::LoopSet)
function choose_order_cost(ls::LoopSet)
resize!(ls.loop_order, length(ls.loopsymbols))
if num_loops(ls) > 1
torder, tunroll, ttile, tvec, tU, tT, tc = choose_tile(ls)
else
Expand All @@ -665,6 +665,10 @@ function choose_order(ls::LoopSet)
return uorder, first(uorder), Symbol("##undefined##"), uvec, determine_unroll_factor(ls, uorder, first(uorder), uvec), -1, uc
end
end
function choose_order(ls::LoopSet)
order, unroll, tile, vec, U, T, c = choose_order_cost(ls)
order, unroll, tile, vec, U, T
end

function register_pressure(ls::LoopSet, U, T)
if T == -1
Expand Down
1 change: 1 addition & 0 deletions src/operation_evaluation_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ end

function fillorder!(ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, tiled::Symbol, loopistiled::Bool)
lo = ls.loop_order
resize!(lo, length(ls.loopsymbols))
ro = lo.loopnames # reverse order; will have same order as lo
nloops = length(order)
ops = operations(ls)
Expand Down
27 changes: 16 additions & 11 deletions src/split_loops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ function add_operation!(ls_new::LoopSet, included::Vector{Int}, ls::LoopSet, op:
end
opnew = Operation(
length(operations(ls_new)), name(op), op.elementbytes, instruction(op), op.node_type,
loopdependencies(op), reduceddependencies(op), vparents, ref(op), reducedchildren(op)
loopdependencies(op), reduceddependencies(op), vparents, op.ref, reducedchildren(op)
)
push!(operations(ls_new), opnew)
included[identifier(op)] = identifier(opnew)
opnew
end
Expand Down Expand Up @@ -45,33 +46,37 @@ function split_loopset(ls::LoopSet, ids)
ls_new
end


function lower_and_split_loops(ls::LoopSet)
function returned_ops(ls::LoopSet)
ops = operations(ls)
split_candidates = Int[]
retops = Int[]
for op ops
isstore(op) && push!(split_candidates, identifier(op))
isstore(op) && push!(retops, identifier(op))
end
for i ls.outer_reductions
push!(split_candidates, i)
push!(retops, i)
end
retops
end

function lower_and_split_loops(ls::LoopSet)
split_candidates = returned_ops(ls)
length(split_candidates) > 1 || return lower(ls)
order_fused, unrolled_fused, tiled_fused, vectorized_fused, U_fused, T_fused, cost_fused = choose_order(ls)
order_fused, unrolled_fused, tiled_fused, vectorized_fused, U_fused, T_fused, cost_fused = choose_order_cost(ls)
remaining_ops = Vector{Int}(undef, length(split_candidates) - 1); split_1 = Int[0];
for (ind,i) enumerate(split_candidates)
split_1[1] = i
ls_1 = split_loopset(ls, split_1)
order_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1, cost_1 = choose_order(ls_1)
reaminig_ops[1:ind-1] .= @view(split_candidates[1:ind-1]); reaminig_ops[ind:end] .= @view(split_candidates[ind+1:end])
order_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1, cost_1 = choose_order_cost(ls_1)
remaining_ops[1:ind-1] .= @view(split_candidates[1:ind-1]); remaining_ops[ind:end] .= @view(split_candidates[ind+1:end])
ls_2 = split_loopset(ls, remaining_ops)
order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2, cost_2 = choose_order(ls_2)
order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2, cost_2 = choose_order_cost(ls_2)
if cost_1 + cost_2 < cost_fused
ls_2_lowered = if length(remaining_ops) > 1
lower_and_split_loops(ls_2)
else
lower(ls_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2)
end
Expr(
return Expr(
:block,
ls.preamble,
lower(ls_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1),
Expand Down

0 comments on commit 8771d75

Please sign in to comment.