Skip to content

Commit

Permalink
Constant cse (#865)
Browse files Browse the repository at this point in the history
* Constant cse

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix test

* traced if constants

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
wsmoses and github-actions[bot] authored Mar 9, 2025
1 parent 8e1d45d commit c705436
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 9 deletions.
64 changes: 61 additions & 3 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,56 @@ struct Token
mlir_data::MLIR.IR.Value
end

function activate_constant_context!(blk::MLIR.IR.Block)
stack = get!(task_local_storage(), :entry_block) do
return Tuple{MLIR.IR.Block,Dict{MLIR.IR.Attribute,TracedRArray}}[]
end
Base.push!(stack, (blk, Dict{MLIR.IR.Attribute,TracedRArray}()))
return nothing
end

function constant_context(; throw_error::Core.Bool=true)
return last(task_local_storage(:entry_block))
end

function deactivate_constant_context!(blk::MLIR.IR.Block)
constant_context()[1] == blk || error("Deactivating wrong block")
return Base.pop!(task_local_storage(:entry_block))
end

# constant ops
@noinline function constant(
x::DenseArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
) where {T,N}
value = MLIR.IR.DenseElementsAttribute(x)
output = mlir_type(TracedRArray{T,N}, size(x))
res = MLIR.IR.result(stablehlo.constant(; output, value, location))
return TracedRArray{T,N}((), res, size(x))
constants = constant_context()[2]
if haskey(constants, value)
return constants[value]
else
output = mlir_type(TracedRArray{T,N}, size(x))

op_ty_results = MLIR.IR.Type[output]
operands = MLIR.IR.Value[]
owned_regions = MLIR.IR.Region[]
successors = MLIR.IR.Block[]
attributes = MLIR.IR.NamedAttribute[MLIR.Dialects.namedattribute("value", value),]

cstop = MLIR.IR.create_operation(
"stablehlo.constant",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)

res = MLIR.IR.result(cstop)
tres = TracedRArray{T,N}((), res, size(x))
constants[value] = tres
return tres
end
end

@noinline function constant(
Expand Down Expand Up @@ -1764,6 +1806,7 @@ end
true_fn_args = true_fn_names[1]

MLIR.IR.activate!(true_fn_body)
Ops.activate_constant_context!(true_fn_body)
tb_result = try
for (i, arg) in enumerate(tb_linear_args)
# find the right path to index the traced arg.
Expand All @@ -1787,6 +1830,7 @@ end
end
Reactant.call_with_reactant(true_fn, tb_traced_args...)
finally
Ops.deactivate_constant_context!(true_fn_body)
MLIR.IR.deactivate!(true_fn_body)
end

Expand Down Expand Up @@ -1827,6 +1871,7 @@ end

false_fn_args = false_fn_names[1]
MLIR.IR.activate!(false_fn_body)
Ops.activate_constant_context!(false_fn_body)
fb_result = try
for (i, arg) in enumerate(fb_linear_args)
# find the right path to index the traced arg.
Expand All @@ -1850,6 +1895,7 @@ end
end
Reactant.call_with_reactant(false_fn, fb_traced_args...)
finally
Ops.deactivate_constant_context!(false_fn_body)
MLIR.IR.deactivate!(false_fn_body)
end

Expand Down Expand Up @@ -1928,6 +1974,7 @@ end

# finalize the true branch by adding the missing values
MLIR.IR.activate!(true_fn_body)
Ops.activate_constant_context!(true_fn_body)
tb_corrected_linear_results = Reactant.TracedType[]
try
for (i, path) in enumerate(tb_paths)
Expand All @@ -1939,10 +1986,12 @@ end
end
finally
MLIR.IR.deactivate!(true_fn_body)
Ops.deactivate_constant_context!(true_fn_body)
end

# finalize the false branch by adding the missing values
MLIR.IR.activate!(false_fn_body)
Ops.activate_constant_context!(false_fn_body)
fb_corrected_linear_results = Reactant.TracedType[]
try
for (i, path) in enumerate(fb_paths)
Expand All @@ -1954,6 +2003,7 @@ end
end
finally
MLIR.IR.deactivate!(false_fn_body)
Ops.deactivate_constant_context!(false_fn_body)
end

# All MissingTracedValues must be replaced with zeroes
Expand All @@ -1968,19 +2018,23 @@ end
res = if tr isa MissingTracedValue
@assert !(fr isa MissingTracedValue)
MLIR.IR.activate!(true_fn_body)
Ops.activate_constant_context!(true_fn_body)
try
tb_corrected_linear_results[i] = zero(fr)
finally
MLIR.IR.deactivate!(true_fn_body)
Ops.deactivate_constant_context!(true_fn_body)
end
fr
elseif fr isa MissingTracedValue
@assert !(tr isa MissingTracedValue)
MLIR.IR.activate!(false_fn_body)
Ops.activate_constant_context!(false_fn_body)
try
fb_corrected_linear_results[i] = zero(tr)
finally
MLIR.IR.deactivate!(false_fn_body)
Ops.deactivate_constant_context!(false_fn_body)
end
tr
else
Expand All @@ -1993,6 +2047,7 @@ end
end

MLIR.IR.activate!(true_fn_body)
Ops.activate_constant_context!(true_fn_body)
try
vals = MLIR.IR.Value[
Reactant.TracedUtils.get_mlir_data(res) for
Expand All @@ -2001,9 +2056,11 @@ end
MLIR.Dialects.stablehlo.return_(vals)
finally
MLIR.IR.deactivate!(true_fn_body)
Ops.deactivate_constant_context!(true_fn_body)
end

MLIR.IR.activate!(false_fn_body)
Ops.activate_constant_context!(false_fn_body)
try
vals = MLIR.IR.Value[
Reactant.TracedUtils.get_mlir_data(res) for
Expand All @@ -2012,6 +2069,7 @@ end
MLIR.Dialects.stablehlo.return_(vals)
finally
MLIR.IR.deactivate!(false_fn_body)
Ops.deactivate_constant_context!(false_fn_body)
end

# With the corrected results, we can compile the true and false branches
Expand Down
2 changes: 2 additions & 0 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ function make_mlir_fn(

fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args])
push!(MLIR.IR.region(func, 1), fnbody)
Ops.activate_constant_context!(fnbody)

@assert MLIR.IR._has_block()

Expand All @@ -265,6 +266,7 @@ function make_mlir_fn(
end
finally
MLIR.IR.deactivate!(fnbody)
Ops.deactivate_constant_context!(fnbody)
end

# check which arguments have been mutated
Expand Down
22 changes: 16 additions & 6 deletions src/mlir/IR/Operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ This will return true if the dialect is loaded and the operation is registered w
is_registered(opname; context::Context=context()) =
API.mlirContextIsRegisteredOperation(context, opname)

function create_operation(
function create_operation_common(
name,
loc;
results=nothing,
Expand Down Expand Up @@ -320,10 +320,20 @@ function create_operation(
if mlirIsNull(op)
error("Create Operation '$name' failed")
end
res = Operation(op, true)
if _has_block()
push!(block(), res)
end
return res
return Operation(op, true)
end
end

function create_operation(args...; kwargs...)
res = create_operation_common(args...; kwargs...)
if _has_block()
push!(block(), res)
end
return res
end

function create_operation_at_front(args...; kwargs...)
res = create_operation_common(args...; kwargs...)
Base.pushfirst!(block(), res)
return res
end
19 changes: 19 additions & 0 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1101,3 +1101,22 @@ end
r = reduce(+, A; dims=dims, init=init)
@test r_hlo squeeze_dims(r)
end

@testset "const dedup" begin
x = Reactant.to_rarray([11, 12, 13, 14])
function const_dedup(x)
c1 = [1, 2, 3, 4]
y1 = (x .+ c1)
c2 = [1, 2, 3, 4]
y2 = (x .+ c2)
c1[1] = 6
return y1 .* y2 .* c1
end

mod = @code_hlo optimize = false const_dedup(x)
hlo_ir = repr(mod)
csts = collect(x for x in eachsplit(hlo_ir, "\n") if occursin("stablehlo.constant", x))
@test length(csts) == 2
@test occursin("1, 2, 3, 4", csts[1])
@test occursin("6, 2, 3, 4", csts[2])
end

0 comments on commit c705436

Please sign in to comment.