Skip to content

Commit

Permalink
optimizer: enable SROA of mutable φ-nodes
Browse files Browse the repository at this point in the history
This commit allows elimination of mutable φ-node (and its predecessor mutables allocations).
As an contrived example, it allows this `mutable_ϕ_elim(::String, ::Vector{String})`
to run without any allocations at all:
```julia
function mutable_ϕ_elim(x, xs)
    r = Ref(x)
    for x in xs
        r = Ref(x)
    end
    return r[]
end

let xs = String[string(gensym()) for _ in 1:100]
    mutable_ϕ_elim("init", xs)
    @test @allocated(mutable_ϕ_elim("init", xs)) == 0
end
```

This mutable ϕ-node elimination is still limited though.
Most notably, the current implementation doesn't work if a mutable
allocation forms multiple ϕ-nodes, since we check allocation eliminability
(i.e. escapability) by counting usages counts and thus it's hard to
reason about multiple ϕ-nodes at a time.
For example, currently mutable allocations involved in cases like below
will still not be eliminated:
```julia
code_typed((Bool,String,String),) do cond, x, y
    if cond
        ϕ2 = ϕ1 = Ref(x)
    else
        ϕ2 = ϕ1 = Ref(y)
    end
    ϕ1[], ϕ2[]
end

\# more realistic example
mutable struct Point{T}
    x::T
    y::T
end
add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y)
function compute(a::Point{ComplexF64}, b::Point{ComplexF64})
    for i in 0:(100000000-1)
        a = add(add(a, b), b)
    end
    a.x, a.y
end
```

I'd say this limitation should be addressed by first introducing a better
abstraction for reasoning escape information. More specifically, I'd like
introduce EscapeAnalysis.jl into Julia base first, and then gradually
adapt it to improve our SROA pass, since EA will allow us to reason about
all escape information imposed on whatever object more easily and should
help us get rid of the complexities of our current SROA implementation.

For now, I'd like to get in this enhancement even though it has the
limitation elaborated above, as far as this commit doesn't introduce
latency problem (which is unlikely).
  • Loading branch information
aviatesk committed Dec 23, 2021
1 parent 250059f commit c19c4e4
Show file tree
Hide file tree
Showing 2 changed files with 348 additions and 54 deletions.
218 changes: 165 additions & 53 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
end
end

# even when the allocation contains an uninitialized field, we try an extra effort to check
# if this load at `idx` have any "safe" `setfield!` calls that define the field
function has_safe_def(
# even when the allocation contains an uninitialized field, we try an extra effort to
# check if all loads have "safe" `setfield!` calls that define the uninitialized field
function has_safe_def_for_undef_field(
ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse,
newidx::Int, idx::Int)
def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx)
Expand Down Expand Up @@ -207,14 +207,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end

function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@nospecialize(typeconstraint))
callback = function (@nospecialize(pi), @nospecialize(idx))
if isa(pi, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
@nospecialize(typeconstraint), @nospecialize(callback = nothing))
newcallback = function (@nospecialize(x), @nospecialize(idx))
if isa(x, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(x.typ))
end
callback === nothing || callback(x, idx)
return false
end
def = simple_walk(compact, defssa, callback)
def = simple_walk(compact, defssa, newcallback)
return Pair{Any, Any}(def, typeconstraint)
end

Expand All @@ -224,7 +225,9 @@ end
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
(pruning those leaves rules out by path conditions).
"""
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
function walk_to_defs(compact::IncrementalCompact,
@nospecialize(defssa), @nospecialize(typeconstraint),
@nospecialize(callback = nothing))
visited_phinodes = AnySSAValue[]
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
def = compact[defssa]
Expand Down Expand Up @@ -260,7 +263,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
val = OldSSAValue(val.id)
end
if isa(val, AnySSAValue)
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint)
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint, callback)
if isa(new_def, AnySSAValue)
if !haskey(visited_constraints, new_def)
push!(worklist_defs, new_def)
Expand Down Expand Up @@ -721,10 +724,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
continue
end
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
end
mid, defuse = get!(defuses, defidx) do
SPCSet(), SSADefUse()
mid, defuse, phidefs = get!(defuses, defidx) do
SPCSet(), SSADefUse(), PhiDefs(nothing)
end
push!(defuse.ccall_preserve_uses, idx)
union!(mid, intermediaries)
Expand Down Expand Up @@ -779,16 +782,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
# Mutable stuff here
isa(def, SSAValue) || continue
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
end
mid, defuse = get!(defuses, def.id) do
SPCSet(), SSADefUse()
mid, defuse, phidefs = get!(defuses, def.id) do
SPCSet(), SSADefUse(), PhiDefs(nothing)
end
if is_setfield
push!(defuse.defs, idx)
else
push!(defuse.uses, idx)
end
defval = compact[def]
if isa(defval, PhiNode)
phicallback = function (@nospecialize(x), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
defs, _ = walk_to_defs(compact, def, struct_typ, phicallback)
if _any(@nospecialize(d)->!isa(d, SSAValue), defs)
delete!(defuses, def.id)
continue
end
phidefs[] = Int[(def::SSAValue).id for def in defs]
end
union!(mid, intermediaries)
end
continue
Expand Down Expand Up @@ -848,43 +864,72 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
end
end

# TODO:
# - run mutable SROA on the same IR as when we collect information about mutable allocations
# - simplify and improve the eliminability check below using an escape analysis

const PhiDefs = RefValue{Union{Nothing,Vector{Int}}}

function sroa_mutables!(ir::IRCode,
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}, used_ssas::Vector{Int},
nested_loads::NestedLoads)
local domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
local any_eliminated = false
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
for (idx, (intermediaries, defuse, phidefs)) in sort!(collect(defuses); by=first, rev=true)
intermediaries = collect(intermediaries)
phidefs = phidefs[]
# Check if there are any uses we did not account for. If so, the variable
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
# show up in the nuses_total count.
nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
nleaves = count_leaves(defuse)
if phidefs !== nothing
# if this defines ϕ, we also track leaves of all predecessors as well
# FIXME this doesn't work when any predecessor is used by another ϕ-node
for pidx in phidefs
haskey(defuses, pidx) || continue
pdefuse = defuses[pidx][2]
nleaves += count_leaves(pdefuse)
end
end
nuses = 0
for idx in intermediaries
nuses += used_ssas[idx]
end
nuses_total = used_ssas[idx] + nuses - length(intermediaries)
nuses -= length(intermediaries)
nuses_total = used_ssas[idx] + nuses
if phidefs !== nothing
for pidx in phidefs
# NOTE we don't need to accout for intermediates for this predecessor here,
# since they are already included in intermediates of this ϕ-node
# FIXME this doesn't work when any predecessor is used by another ϕ-node
nuses_total += used_ssas[pidx] - 1 # substract usage count from ϕ-node itself
end
end
nleaves == nuses_total || continue
# Find the type for this allocation
defexpr = ir[SSAValue(idx)]
isa(defexpr, Expr) || continue
if !isexpr(defexpr, :new)
if is_known_call(defexpr, getfield, ir)
val = defexpr.args[2]
if isa(val, SSAValue)
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
if ismutabletype(struct_typ)
record_nested_load!(nested_mloads, idx)
end
if isa(defexpr, Expr)
@assert phidefs === nothing
if !isexpr(defexpr, :new)
maybe_record_nested_load!(nested_mloads, ir, idx)
continue
end
elseif isa(defexpr, PhiNode)
phidefs === nothing && continue
for pidx in phidefs
pexpr = ir[SSAValue(pidx)]
if !isexpr(pexpr, :new)
maybe_record_nested_load!(nested_mloads, ir, pidx)
@goto skip
end
end
else
continue
end
newidx = idx
typ = ir.stmts[newidx][:type]
typ = ir.stmts[idx][:type]
if isa(typ, UnionAll)
typ = unwrap_unionall(typ)
end
Expand All @@ -896,25 +941,29 @@ function sroa_mutables!(ir::IRCode,
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
all_forwarded = true
for use in defuse.uses
stmt = ir[SSAValue(use)] # == `getfield` call
# We may have discovered above that this use is dead
# after the getfield elim of immutables. In that case,
# it would have been deleted. That's fine, just ignore
# the use in that case.
if stmt === nothing
eliminable = check_use_eliminability!(fielddefuse, ir, use, typ)
if eliminable === nothing
# We may have discovered above that this use is dead
# after the getfield elim of immutables. In that case,
# it would have been deleted. That's fine, just ignore
# the use in that case.
all_forwarded = false
continue
elseif !eliminable
@goto skip
end
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
field === nothing && @goto skip
push!(fielddefuse[field].uses, use)
end
for def in defuse.defs
stmt = ir[SSAValue(def)]::Expr # == `setfield!` call
field = try_compute_fieldidx_stmt(ir, stmt, typ)
field === nothing && @goto skip
isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error
push!(fielddefuse[field].defs, def)
check_def_eliminability!(fielddefuse, ir, def, typ) || @goto skip
end
if phidefs !== nothing
for pidx in phidefs
haskey(defuses, pidx) || continue
pdefuse = defuses[pidx][2]
for pdef in pdefuse.defs
check_def_eliminability!(fielddefuse, ir, pdef, typ) || @goto skip
end
end
end
# Check that the defexpr has defined values for all the fields
# we're accessing. In the future, we may want to relax this,
Expand All @@ -925,7 +974,13 @@ function sroa_mutables!(ir::IRCode,
for fidx in 1:ndefuse
du = fielddefuse[fidx]
isempty(du.uses) && continue
push!(du.defs, newidx)
if phidefs === nothing
push!(du.defs, idx)
else
for pidx in phidefs
push!(du.defs, pidx)
end
end
ldu = compute_live_ins(ir.cfg, du)
if isempty(ldu.live_in_bbs)
phiblocks = Int[]
Expand All @@ -935,10 +990,24 @@ function sroa_mutables!(ir::IRCode,
end
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
blocks[fidx] = phiblocks, allblocks
if fidx + 1 > length(defexpr.args)
for use in du.uses
if phidefs !== nothing
# check if all predecessors have safe definitions
for pidx in phidefs
newexpr = ir[SSAValue(pidx)]::Expr # == new(...)
if fidx + 1 > length(newexpr.args) # this field can be undefined
domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
for use in du.uses
has_safe_def_for_undef_field(ir, domtree, allblocks, du, pidx, use) || @goto skip
end
end
end
else
newexpr = defexpr::Expr # == new(...)
if fidx + 1 > length(newexpr.args) # this field can be undefined
domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip
for use in du.uses
has_safe_def_for_undef_field(ir, domtree, allblocks, du, idx, use) || @goto skip
end
end
end
end
Expand Down Expand Up @@ -983,17 +1052,24 @@ function sroa_mutables!(ir::IRCode,
end
end
end
for stmt in du.defs
stmt == newidx && continue
ir[SSAValue(stmt)] = nothing
if isa(defexpr, PhiNode)
ir[SSAValue(idx)] = nothing
for pidx in phidefs::Vector{Int}
used_ssas[pidx] -= 1
end
else
for stmt in du.defs
stmt == idx && continue
ir[SSAValue(stmt)] = nothing
end
end
end
preserve_uses === nothing && continue
if all_forwarded
# this means all ccall preserves have been replaced with forwarded loads
# so we can potentially eliminate the allocation, otherwise we must preserve
# the whole allocation.
push!(intermediaries, newidx)
push!(intermediaries, idx)
end
# Insert the new preserves
for (use, new_preserves) in preserve_uses
Expand All @@ -1005,6 +1081,42 @@ function sroa_mutables!(ir::IRCode,
return any_eliminated ? sroa_pass!(compact!(ir), false) : ir
end

count_leaves(defuse::SSADefUse) =
length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)

function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int)
defexpr = ir[SSAValue(idx)]
if is_known_call(defexpr, getfield, ir)
val = defexpr.args[2]
if isa(val, SSAValue)
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
if ismutabletype(struct_typ)
record_nested_load!(nested_mloads, idx)
end
end
end
end

function check_use_eliminability!(fielddefuse::Vector{SSADefUse},
ir::IRCode, useidx::Int, struct_typ::DataType)
stmt = ir[SSAValue(useidx)] # == `getfield` call
stmt === nothing && return nothing
field = try_compute_fieldidx_stmt(ir, stmt::Expr, struct_typ)
field === nothing && return false
push!(fielddefuse[field].uses, useidx)
return true
end

function check_def_eliminability!(fielddefuse::Vector{SSADefUse},
ir::IRCode, defidx::Int, struct_typ::DataType)
stmt = ir[SSAValue(defidx)]::Expr # == `setfield!` call
field = try_compute_fieldidx_stmt(ir, stmt, struct_typ)
field === nothing && return false
isconst(struct_typ, field) && return false # we discovered an attempt to mutate a const field, which must error
push!(fielddefuse[field].defs, defidx)
return true
end

function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
newex = Expr(:foreigncall)
nccallargs = length(origex.args[3]::SimpleVector)
Expand Down
Loading

0 comments on commit c19c4e4

Please sign in to comment.