diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 5333aea03caef4..62e9c7f1ded658 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -100,9 +100,9 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I end end -# even when the allocation contains an uninitialized field, we try an extra effort to check -# if this load at `idx` have any "safe" `setfield!` calls that define the field -function has_safe_def( +# even when the allocation contains an uninitialized field, we try an extra effort to +# check if all loads have "safe" `setfield!` calls that define the uninitialized field +function has_safe_def_for_undef_field( ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, newidx::Int, idx::Int) def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx) @@ -207,14 +207,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA end function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), - @nospecialize(typeconstraint)) - callback = function (@nospecialize(pi), @nospecialize(idx)) - if isa(pi, PiNode) - typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ)) + @nospecialize(typeconstraint), @nospecialize(callback = nothing)) + newcallback = function (@nospecialize(x), @nospecialize(idx)) + if isa(x, PiNode) + typeconstraint = typeintersect(typeconstraint, widenconst(x.typ)) end + callback === nothing || callback(x, idx) return false end - def = simple_walk(compact, defssa, callback) + def = simple_walk(compact, defssa, newcallback) return Pair{Any, Any}(def, typeconstraint) end @@ -224,7 +225,9 @@ end Starting at `val` walk use-def chains to get all the leaves feeding into this `val` (pruning those leaves rules out by path conditions). """ -function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint)) +function walk_to_defs(compact::IncrementalCompact, + @nospecialize(defssa), @nospecialize(typeconstraint), + @nospecialize(callback = nothing)) visited_phinodes = AnySSAValue[] isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes def = compact[defssa] @@ -260,7 +263,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe val = OldSSAValue(val.id) end if isa(val, AnySSAValue) - new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint) + new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint, callback) if isa(new_def, AnySSAValue) if !haskey(visited_constraints, new_def) push!(worklist_defs, new_def) @@ -721,10 +724,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) continue end if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}() end - mid, defuse = get!(defuses, defidx) do - SPCSet(), SSADefUse() + mid, defuse, phidefs = get!(defuses, defidx) do + SPCSet(), SSADefUse(), PhiDefs(nothing) end push!(defuse.ccall_preserve_uses, idx) union!(mid, intermediaries) @@ -779,16 +782,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) # Mutable stuff here isa(def, SSAValue) || continue if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}() end - mid, defuse = get!(defuses, def.id) do - SPCSet(), SSADefUse() + mid, defuse, phidefs = get!(defuses, def.id) do + SPCSet(), SSADefUse(), PhiDefs(nothing) end if is_setfield push!(defuse.defs, idx) else push!(defuse.uses, idx) end + defval = compact[def] + if isa(defval, PhiNode) + phicallback = function (@nospecialize(x), @nospecialize(ssa)) + push!(intermediaries, ssa.id) + return false + end + defs, _ = walk_to_defs(compact, def, struct_typ, phicallback) + if _any(@nospecialize(d)->!isa(d, SSAValue), defs) + delete!(defuses, def.id) + continue + end + phidefs[] = Int[(def::SSAValue).id for def in defs] + end union!(mid, intermediaries) end continue @@ -848,43 +864,72 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) end end +# TODO: +# - run mutable SROA on the same IR as when we collect information about mutable allocations +# - simplify and improve the eliminability check below using an escape analysis + +const PhiDefs = RefValue{Union{Nothing,Vector{Int}}} + function sroa_mutables!(ir::IRCode, - defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, + defuses::IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}, used_ssas::Vector{Int}, nested_loads::NestedLoads) local domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable` local any_eliminated = false # NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield` - for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true) + for (idx, (intermediaries, defuse, phidefs)) in sort!(collect(defuses); by=first, rev=true) intermediaries = collect(intermediaries) + phidefs = phidefs[] # Check if there are any uses we did not account for. If so, the variable # escapes and we cannot eliminate the allocation. This works, because we're guaranteed # not to include any intermediaries that have dead uses. As a result, missing uses will only ever # show up in the nuses_total count. - nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses) + nleaves = count_leaves(defuse) + if phidefs !== nothing + # if this defines ϕ, we also track leaves of all predecessors as well + # FIXME this doesn't work when any predecessor is used by another ϕ-node + for pidx in phidefs + haskey(defuses, pidx) || continue + pdefuse = defuses[pidx][2] + nleaves += count_leaves(pdefuse) + end + end nuses = 0 for idx in intermediaries nuses += used_ssas[idx] end - nuses_total = used_ssas[idx] + nuses - length(intermediaries) + nuses -= length(intermediaries) + nuses_total = used_ssas[idx] + nuses + if phidefs !== nothing + for pidx in phidefs + # NOTE we don't need to accout for intermediates for this predecessor here, + # since they are already included in intermediates of this ϕ-node + # FIXME this doesn't work when any predecessor is used by another ϕ-node + nuses_total += used_ssas[pidx] - 1 # substract usage count from ϕ-node itself + end + end nleaves == nuses_total || continue # Find the type for this allocation defexpr = ir[SSAValue(idx)] - isa(defexpr, Expr) || continue - if !isexpr(defexpr, :new) - if is_known_call(defexpr, getfield, ir) - val = defexpr.args[2] - if isa(val, SSAValue) - struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) - if ismutabletype(struct_typ) - record_nested_load!(nested_mloads, idx) - end + if isa(defexpr, Expr) + @assert phidefs === nothing + if !isexpr(defexpr, :new) + maybe_record_nested_load!(nested_mloads, ir, idx) + continue + end + elseif isa(defexpr, PhiNode) + phidefs === nothing && continue + for pidx in phidefs + pexpr = ir[SSAValue(pidx)] + if !isexpr(pexpr, :new) + maybe_record_nested_load!(nested_mloads, ir, pidx) + @goto skip end end + else continue end - newidx = idx - typ = ir.stmts[newidx][:type] + typ = ir.stmts[idx][:type] if isa(typ, UnionAll) typ = unwrap_unionall(typ) end @@ -896,25 +941,29 @@ function sroa_mutables!(ir::IRCode, fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)] all_forwarded = true for use in defuse.uses - stmt = ir[SSAValue(use)] # == `getfield` call - # We may have discovered above that this use is dead - # after the getfield elim of immutables. In that case, - # it would have been deleted. That's fine, just ignore - # the use in that case. - if stmt === nothing + eliminable = check_use_eliminability!(fielddefuse, ir, use, typ) + if eliminable === nothing + # We may have discovered above that this use is dead + # after the getfield elim of immutables. In that case, + # it would have been deleted. That's fine, just ignore + # the use in that case. all_forwarded = false continue + elseif !eliminable + @goto skip end - field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ) - field === nothing && @goto skip - push!(fielddefuse[field].uses, use) end for def in defuse.defs - stmt = ir[SSAValue(def)]::Expr # == `setfield!` call - field = try_compute_fieldidx_stmt(ir, stmt, typ) - field === nothing && @goto skip - isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error - push!(fielddefuse[field].defs, def) + check_def_eliminability!(fielddefuse, ir, def, typ) || @goto skip + end + if phidefs !== nothing + for pidx in phidefs + haskey(defuses, pidx) || continue + pdefuse = defuses[pidx][2] + for pdef in pdefuse.defs + check_def_eliminability!(fielddefuse, ir, pdef, typ) || @goto skip + end + end end # Check that the defexpr has defined values for all the fields # we're accessing. In the future, we may want to relax this, @@ -925,7 +974,13 @@ function sroa_mutables!(ir::IRCode, for fidx in 1:ndefuse du = fielddefuse[fidx] isempty(du.uses) && continue - push!(du.defs, newidx) + if phidefs === nothing + push!(du.defs, idx) + else + for pidx in phidefs + push!(du.defs, pidx) + end + end ldu = compute_live_ins(ir.cfg, du) if isempty(ldu.live_in_bbs) phiblocks = Int[] @@ -935,10 +990,24 @@ function sroa_mutables!(ir::IRCode, end allblocks = sort(vcat(phiblocks, ldu.def_bbs)) blocks[fidx] = phiblocks, allblocks - if fidx + 1 > length(defexpr.args) - for use in du.uses + if phidefs !== nothing + # check if all predecessors have safe definitions + for pidx in phidefs + newexpr = ir[SSAValue(pidx)]::Expr # == new(...) + if fidx + 1 > length(newexpr.args) # this field can be undefined + domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) + for use in du.uses + has_safe_def_for_undef_field(ir, domtree, allblocks, du, pidx, use) || @goto skip + end + end + end + else + newexpr = defexpr::Expr # == new(...) + if fidx + 1 > length(newexpr.args) # this field can be undefined domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) - has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip + for use in du.uses + has_safe_def_for_undef_field(ir, domtree, allblocks, du, idx, use) || @goto skip + end end end end @@ -983,9 +1052,16 @@ function sroa_mutables!(ir::IRCode, end end end - for stmt in du.defs - stmt == newidx && continue - ir[SSAValue(stmt)] = nothing + if isa(defexpr, PhiNode) + ir[SSAValue(idx)] = nothing + for pidx in phidefs::Vector{Int} + used_ssas[pidx] -= 1 + end + else + for stmt in du.defs + stmt == idx && continue + ir[SSAValue(stmt)] = nothing + end end end preserve_uses === nothing && continue @@ -993,7 +1069,7 @@ function sroa_mutables!(ir::IRCode, # this means all ccall preserves have been replaced with forwarded loads # so we can potentially eliminate the allocation, otherwise we must preserve # the whole allocation. - push!(intermediaries, newidx) + push!(intermediaries, idx) end # Insert the new preserves for (use, new_preserves) in preserve_uses @@ -1005,6 +1081,42 @@ function sroa_mutables!(ir::IRCode, return any_eliminated ? sroa_pass!(compact!(ir), false) : ir end +count_leaves(defuse::SSADefUse) = + length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses) + +function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int) + defexpr = ir[SSAValue(idx)] + if is_known_call(defexpr, getfield, ir) + val = defexpr.args[2] + if isa(val, SSAValue) + struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) + if ismutabletype(struct_typ) + record_nested_load!(nested_mloads, idx) + end + end + end +end + +function check_use_eliminability!(fielddefuse::Vector{SSADefUse}, + ir::IRCode, useidx::Int, struct_typ::DataType) + stmt = ir[SSAValue(useidx)] # == `getfield` call + stmt === nothing && return nothing + field = try_compute_fieldidx_stmt(ir, stmt::Expr, struct_typ) + field === nothing && return false + push!(fielddefuse[field].uses, useidx) + return true +end + +function check_def_eliminability!(fielddefuse::Vector{SSADefUse}, + ir::IRCode, defidx::Int, struct_typ::DataType) + stmt = ir[SSAValue(defidx)]::Expr # == `setfield!` call + field = try_compute_fieldidx_stmt(ir, stmt, struct_typ) + field === nothing && return false + isconst(struct_typ, field) && return false # we discovered an attempt to mutate a const field, which must error + push!(fielddefuse[field].defs, defidx) + return true +end + function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) newex = Expr(:foreigncall) nccallargs = length(origex.args[3]::SimpleVector) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 2d1d85e3df97fd..5dcad9e6e83f03 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -230,7 +230,7 @@ let src = code_typed1((Any,Any,Any)) do x, y, z end end # FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until -# any nested mutable `getfield` calls become no longer eliminatable: +# any nested mutable `getfield` calls become no longer eliminable: # it's probably not the most efficient option and we may want to introduce some sort of # alias analysis and eliminates all the loads at once. # mutable(immutable(...)) case @@ -308,6 +308,188 @@ let # NOTE `sroa_mutables!` eliminate from innermost definitions, so that it sho @test !any(isnew, src.code) end +# ϕ-allocation elimination +# ------------------------ +mutable struct MutableSome + x::Any + MutableSome(@nospecialize x) = new(x) + MutableSome() = new() +end +Base.getindex(s::MutableSome) = s.x +Base.setindex!(s::MutableSome, @nospecialize x) = s.x = x +@testset "mutable ϕ-allocation elimination" begin + # safe cases + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end + let src = code_typed1((Bool,Bool,Any,Any,Any)) do cond1, cond2, x, y, z + if cond1 + ϕ = MutableSome(x) + elseif cond2 + ϕ = MutableSome(y) + else + ϕ = MutableSome(z) + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(4) in x.values && + #=y=# Core.Argument(5) in x.values && + #=z=# Core.Argument(6) in x.values + end == 1 + end + let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + ϕ[] = z + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 + end + let src = code_typed1((Bool,Any,Any,)) do cond, x, y + if cond + ϕ = MutableSome(x) + out1 = ϕ[] + else + ϕ = MutableSome(y) + out1 = ϕ[] + end + out2 = ϕ[] + out1, out2 + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 2 + end + let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = MutableSome(x) + out1 = ϕ[] + else + ϕ = MutableSome(y) + out1 = ϕ[] + ϕ[] = z + end + out2 = ϕ[] + out1, out2 + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 + end + + # unsafe cases + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + end + some_escape(ϕ) + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + some_escape(ϕ) + else + ϕ = MutableSome(y) + end + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,)) do cond, x + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome() + end + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome() + ϕ[] = y + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end + + # FIXME allocation forming multiple ϕ + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ2 = ϕ1 = MutableSome(x) + else + ϕ2 = ϕ1 = MutableSome(y) + end + ϕ1[], ϕ2[] + end + @test_broken !any(isnew, src.code) + @test_broken count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end +end +function mutable_ϕ_elim(x, xs) + r = Ref(x) + for x in xs + r = Ref(x) + end + return r[] +end +let xs = String[string(gensym()) for _ in 1:100] + mutable_ϕ_elim("init", xs) + @test @allocated(mutable_ϕ_elim("init", xs)) == 0 +end + # should work nicely with inlining to optimize away a complicated case # adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B struct Point