Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ end
# Whether the Julia runtime is available for this job and no special handling is
# required. The generic fallback for any `CompilerJob` answers `false`.
function uses_julia_runtime(@nospecialize(job::CompilerJob))
    return false
end

# Whether code should be emitted in imaging mode, i.e. without embedding concrete
# runtime addresses. A job forwards the question to its configured target
# (`job.config.target`); the fallback for any `AbstractCompilerTarget` is `false`.
function imaging_mode(@nospecialize(job::CompilerJob))
    return imaging_mode(job.config.target)
end
function imaging_mode(@nospecialize(target::AbstractCompilerTarget))
    return false
end

# Whether it is legal to run vectorization passes for this job's target.
# The generic fallback for any `CompilerJob` answers `false`.
function can_vectorize(@nospecialize(job::CompilerJob))
    return false
end

Expand Down
11 changes: 9 additions & 2 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,8 @@ function compile_method_instance(@nospecialize(job::CompilerJob))

# set-up the compiler interface
debug_info_kind = llvm_debug_info(job)
imaging = imaging_mode(job)

cgparams = (;
track_allocations = false,
code_coverage = false,
Expand All @@ -725,6 +727,9 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
debug_info_kind = Cint(debug_info_kind),
safepoint_on_entry = can_safepoint(job),
gcstack_arg = false)
if :use_jlplt in fieldnames(Base.CodegenParams)
cgparams = (; cgparams..., use_jlplt = imaging)
end
if VERSION < v"1.12.0-DEV.1667"
cgparams = (; lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb), cgparams... )
end
Expand All @@ -748,6 +753,8 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
Metadata(ConstantInt(DEBUG_METADATA_VERSION()))
end

imaging_flag = imaging ? 1 : 0

native_code = if VERSION >= v"1.12.0-DEV.1823"
codeinfos = Any[]
for (ci, src) in populated
Expand All @@ -760,11 +767,11 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
elseif VERSION >= v"1.12.0-DEV.1667"
ccall(:jl_create_native, Ptr{Cvoid},
(Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint, Cint, Cint, Csize_t, Ptr{Cvoid}),
[job.source], ts_mod, Ref(params), CompilationPolicyExtern, #=imaging mode=# 0, #=external linkage=# 0, job.world, Base.unsafe_convert(Ptr{Nothing}, lookup_cb))
[job.source], ts_mod, Ref(params), CompilationPolicyExtern, imaging_flag, #=external linkage=# 0, job.world, Base.unsafe_convert(Ptr{Nothing}, lookup_cb))
else
ccall(:jl_create_native, Ptr{Cvoid},
(Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint, Cint, Cint, Csize_t),
[job.source], ts_mod, Ref(params), CompilationPolicyExtern, #=imaging mode=# 0, #=external linkage=# 0, job.world)
[job.source], ts_mod, Ref(params), CompilationPolicyExtern, imaging_flag, #=external linkage=# 0, job.world)
end
@assert native_code != C_NULL

Expand Down
62 changes: 61 additions & 1 deletion src/optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
end
run!(pb, mod, tm)
end

# Make sure any lingering TLS getters are rewritten even if upstream LLVM passes
# transformed them before the GPULowerPTLSPass had a chance to run.
if occursin("StaticCompilerTarget", string(typeof(job.config.target))) &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the opposite of the code pattern we use. We should never hard-code a specific target here, but rather expose hooks for a target to configure the pipeline.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will fix.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would strongly prefer small, self-contained PRs with tests for each feature.

uses_julia_runtime(job)
lower_ptls!(mod)
end

optimize_module!(job, mod)
run!(DeadArgumentEliminationPass(), mod, tm)
return
Expand Down Expand Up @@ -405,7 +413,31 @@ function lower_ptls!(mod::LLVM.Module)

intrinsic = "julia.get_pgcstack"

if haskey(functions(mod), intrinsic)
# On host-style static targets we want a relocatable call into libjulia instead of
# embedding the pointer to the TLS getter. Replace the intrinsic with a declared
# libjulia call to avoid baking absolute addresses that crash in standalone binaries.
if haskey(functions(mod), intrinsic) &&
occursin("StaticCompilerTarget", string(typeof(job.config.target))) &&
uses_julia_runtime(job)

pgc_fn = functions(mod)[intrinsic]
jl_decl = if haskey(functions(mod), "jl_get_pgcstack")
functions(mod)["jl_get_pgcstack"]
else
LLVM.Function(mod, "jl_get_pgcstack", LLVM.FunctionType(LLVM.PointerType()))
end

for use in uses(pgc_fn)
call = user(use)::LLVM.CallInst
@dispose builder=IRBuilder() begin
position!(builder, call)
repl = call!(builder, function_type(jl_decl), jl_decl, LLVM.Value[])
replace_uses!(call, repl)
end
erase!(call)
changed = true
end
elseif haskey(functions(mod), intrinsic)
ptls_getter = functions(mod)[intrinsic]

for use in uses(ptls_getter)
Expand All @@ -419,6 +451,34 @@ function lower_ptls!(mod::LLVM.Module)
end
end

# Newer Julia versions sometimes lower the TLS getter to an inttoptr call that bakes
# the address of `jl_get_pgcstack_static` into the IR. Rewrite those calls as well to
# make sure we always end up with a relocatable reference into libjulia when the
# runtime is linked.
if uses_julia_runtime(job) && occursin("StaticCompilerTarget", string(typeof(job.config.target)))
jl_decl = if haskey(functions(mod), "jl_get_pgcstack")
functions(mod)["jl_get_pgcstack"]
else
LLVM.Function(mod, "jl_get_pgcstack", LLVM.FunctionType(LLVM.PointerType()))
end

for f in functions(mod), bb in blocks(f), inst in instructions(bb)
inst isa LLVM.CallInst || continue

callee = LLVM.called_operand(inst)
if callee isa LLVM.ConstantExpr && occursin("inttoptr", string(callee)) &&
occursin("pgcstack", string(inst))
@dispose builder=IRBuilder() begin
position!(builder, inst)
repl = call!(builder, function_type(jl_decl), jl_decl, LLVM.Value[])
replace_uses!(inst, repl)
end
erase!(inst)
changed = true
end
end
end

return changed
end
# Build the module pass that wraps `lower_ptls!` in a `NewPMModulePass` registered
# under the name "GPULowerPTLS".
function GPULowerPTLSPass()
    return NewPMModulePass("GPULowerPTLS", lower_ptls!)
end
3 changes: 2 additions & 1 deletion test/native.jl
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,8 @@ end
Native.code_execution(mod.foobar, Tuple{Ptr{Int}})) do msg
if VERSION >= v"1.11-"
occursin("invalid LLVM IR", msg) &&
occursin(GPUCompiler.LAZY_FUNCTION, msg) &&
(occursin(GPUCompiler.LAZY_FUNCTION, msg) ||
occursin(GPUCompiler.RUNTIME_FUNCTION, msg)) &&
occursin("call to time", msg) &&
occursin("[1] foobar", msg)
else
Expand Down