Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import Enzyme:
eltype,
API,
TypeTree,
typetree,
typetree, typetree_total,
TypeTreeTable,
only!,
shift!,
Expand Down Expand Up @@ -1114,7 +1114,7 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV

byref = arg.cc

rest = copy(typetree(arg.typ, ctx, dl))
rest = copy(typetree_total(job, arg.typ, ctx, dl))

if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF
# adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader
Expand All @@ -1139,7 +1139,7 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
if sret !== nothing
idx = 0
if !in(0, parmsRemoved)
rest = typetree(sret, ctx, dl)
rest = typetree_total(job, sret, ctx, dl)
push!(
parameter_attributes(f, idx + 1),
StringAttribute("enzyme_type", string(rest)),
Expand All @@ -1161,12 +1161,12 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType()
@assert !retRemoved
rest = if llRT == Ptr{RT}
typeTree = copy(typetree(RT, ctx, dl))
typeTree = copy(typetree_total(job, RT, ctx, dl))
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
only!(typeTree, -1)
typeTree
else
typetree(RT, ctx, dl)
typetree_total(job, RT, ctx, dl)
end
push!(return_attributes(f), StringAttribute("enzyme_type", string(rest)))
end
Expand Down Expand Up @@ -2355,7 +2355,7 @@ function enzyme!(
else
error("illegal annotation type $T")
end
typeTree = typetree(source_typ, ctx, dl, seen)
typeTree = typetree_total(job, source_typ, ctx, dl, seen)
if isboxed
typeTree = copy(typeTree)
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
Expand Down Expand Up @@ -2397,7 +2397,7 @@ function enzyme!(
in(Any, actualRetType.parameters)
TypeTree()
else
typeTree = typetree(actualRetType, ctx, dl, seen)
typeTree = typetree_total(job, actualRetType, ctx, dl, seen)
if !isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)
typeTree = copy(typeTree)
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
Expand Down Expand Up @@ -3539,6 +3539,7 @@ end

# Modified from GPUCompiler/src/irgen.jl:365 lower_byval
function lower_convention(
@nospecialize(job::GPUCompiler.CompilerJob),
@nospecialize(functy::Type),
mod::LLVM.Module,
entry_f::LLVM.Function,
Expand Down Expand Up @@ -3771,7 +3772,7 @@ function lower_convention(
metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
end

typeTree = copy(typetree(actualRetType, ctx, dl, seen))
typeTree = copy(typetree_total(job, actualRetType, ctx, dl, seen))
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
only!(typeTree, -1)
metadata(sretPtr)["enzyme_type"] = to_md(typeTree, ctx)
Expand Down Expand Up @@ -3810,8 +3811,7 @@ function lower_convention(
metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
end
ctx = LLVM.context(entry_f)

typeTree = copy(typetree(arg.typ, ctx, dl, seen))
typeTree = copy(typetree_total(job, arg.typ, ctx, dl, seen))
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
only!(typeTree, -1)
metadata(ptr)["enzyme_type"] = to_md(typeTree, ctx)
Expand All @@ -3825,7 +3825,7 @@ function lower_convention(
parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots),
StringAttribute(
"enzyme_type",
string(typetree(arg.typ, ctx, dl, seen)),
string(typetree_total(job, arg.typ, ctx, dl, seen)),
),
)
push!(
Expand All @@ -3846,7 +3846,7 @@ function lower_convention(
wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm)
ctx = LLVM.context(wrapparm)
push!(wrapper_args, wrapparm)
typeTree = copy(typetree(arg.typ, ctx, dl, seen))
typeTree = copy(typetree_total(job, arg.typ, ctx, dl, seen))
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
only!(typeTree, -1)
push!(
Expand Down Expand Up @@ -3968,7 +3968,7 @@ function lower_convention(
return_attributes(wrapper_f),
StringAttribute(
"enzyme_type",
string(typetree(actualRetType, ctx, dl, seen)),
string(typetree_total(job, actualRetType, ctx, dl, seen)),
),
)
push!(
Expand All @@ -3994,7 +3994,7 @@ function lower_convention(
return_attributes(wrapper_f),
StringAttribute(
"enzyme_type",
string(typetree(actualRetType, ctx, dl, seen)),
string(typetree_total(job, actualRetType, ctx, dl, seen)),
),
)
push!(
Expand Down Expand Up @@ -4023,7 +4023,7 @@ function lower_convention(
return_attributes(wrapper_f),
StringAttribute(
"enzyme_type",
string(typetree(eltype(RetActivity), ctx, dl, seen)),
string(typetree_total(job, eltype(RetActivity), ctx, dl, seen)),
),
)
push!(
Expand Down Expand Up @@ -4061,7 +4061,7 @@ function lower_convention(
return_attributes(wrapper_f),
StringAttribute(
"enzyme_type",
string(typetree(actualRetType, ctx, dl, seen)),
string(typetree_total(job, actualRetType, ctx, dl, seen)),
),
)
push!(
Expand Down Expand Up @@ -4591,6 +4591,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT

if state.lowerConvention
primalf, returnRoots, boxedArgs, loweredArgs, actualRetType = lower_convention(
job,
source_sig,
mod,
primalf,
Expand Down Expand Up @@ -4712,7 +4713,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
source_typ
end

ec = typetree(source_typ, ctx, dl, seen)
ec = typetree_total(job, source_typ, ctx, dl, seen)
if byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF
ec = copy(ec)
merge!(ec, TypeTree(API.DT_Pointer, ctx))
Expand Down Expand Up @@ -4751,7 +4752,7 @@ end
)
else
metadata(inst)["enzyme_type"] =
to_md(typetree(Ptr{Cvoid}, ctx, dl, seen), ctx)
to_md(typetree_total(job, Ptr{Cvoid}, ctx, dl, seen), ctx)
end
end
end
Expand Down Expand Up @@ -4791,7 +4792,7 @@ end
)
if offset < sizeof(jTy) && isa(sz, LLVM.ConstantInt) && sizeof(jTy) - offset >= convert(Int, sz)
lim = convert(Int, sz)
md = to_fullmd(jTy, offset, lim)
md = Core._call_in_world_total(job.world, to_fullmd, jTy, offset, lim)
@assert byref == GPUCompiler.BITS_REF ||
byref == GPUCompiler.MUT_REF
metadata(inst)["enzyme_truetype"] = md
Expand Down
10 changes: 10 additions & 0 deletions src/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,16 @@ end

const TypeTreeTable = IdDict{Any,Union{Nothing,TypeTree}}

"""
typetree_total(job, T, ctx, dl, seen=TypeTreeTable())

A wrapper around `typetree` that ensures the call happens in the correct world for GPUCompiler.
Useful when using typetree from a generated function since typetree is user-extendable.
"""
function typetree_total(@nospecialize(job::GPUCompiler.CompilerJob), @nospecialize(T), ctx, dl, seen=TypeTreeTable())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarity could we call this in world instead of total?

return Core._call_in_world_total(job.world, typetree, T, ctx, dl)
end

"""
function typetree(T, ctx, dl, seen=TypeTreeTable())

Expand Down
Loading