Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions gen/rocrand/generator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using Clang.Generators
using JuliaFormatter

include_dir = normpath(joinpath(ENV["ROCM_PATH"], "include"))
rocrand_dir = joinpath(include_dir, "rocrand")
options = load_options("rocrand/rocrand-generator.toml")

args = get_default_args()
push!(args, "-I$include_dir")

rocrand_h = read(joinpath(rocrand_dir, "rocrand.h"), String)
open("./rocrand.h", "w") do io
println(io, """
#include <stddef.h>

typedef void* hipStream_t;
typedef struct { unsigned int x, y, z, w; } uint4;
""")
print(io, rocrand_h)
end
headers = [
"./rocrand.h"
]

ctx = create_context(headers, args, options)

# build without printing so we can do custom rewriting
build!(ctx, BUILDSTAGE_NO_PRINTING)

# custom rewriter
function rewrite!(e::Expr)
if e.head === :const
@assert Meta.isexpr(e.args[1], :(=))
rhs = e.args[1].args[2]
if Meta.isexpr(rhs, :call)
if rhs.args[1] == :(*) && rhs.args[3] == :f
e.args[1].args[2] = :(Float32($(rhs.args[2])))
elseif rhs.args[1] == :(Cuint)
e.args[1].args[2] = :($(rhs.args[2]) % Cuint)
end
end
return e
end
(e.head === :function && Meta.isexpr(e.args[1], :call)) || return e
f = e.args[1].args[1]
if !(f isa Symbol)
@assert f in (:(Base.getproperty), :(Base.setproperty!), :(Base.propertynames))
return e
end
stmts = e.args[2].args
map!(stmts, stmts) do ex
Meta.isexpr(ex, :macrocall) || return ex
ex.args[1] === Symbol("@ccall") || return ex
# TODO: should this be `@gcsafe_ccall`?
# ex.args[1] = Symbol("@gcsafe_ccall")
Expr(:macrocall, Symbol("@check"), nothing, ex)
end
pushfirst!(stmts, :(AMDGPU.prepare_state()))
return e
end

function rewrite!(dag::ExprDAG)
for node in get_nodes(dag)
for expr in get_exprs(node)
rewrite!(expr)
end
end
end

rewrite!(ctx.dag)

# print
build!(ctx, BUILDSTAGE_PRINTING_ONLY)

path = options["general"]["output_file_path"]
format_file(path, YASStyle())
15 changes: 15 additions & 0 deletions gen/rocrand/rocrand-generator.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[general]
library_name = "librocrand"
output_file_path = "../src/rand/librocrand.jl"
export_symbol_prefixes = []
print_using_CEnum = false
output_ignorelist = [
"(__)?hip.*",
"(__)?HIP.*",
"rocrand_status",
"half",
"SKEIN_KS_PARITY64",
]

[codegen]
use_ccall_macro = true
14 changes: 14 additions & 0 deletions src/rand/error.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
export ROCRANDError

import .AMDGPU: @check, check
using CEnum: @cenum

@cenum rocrand_status::UInt32 begin
ROCRAND_STATUS_SUCCESS = 0
ROCRAND_STATUS_VERSION_MISMATCH = 100
ROCRAND_STATUS_NOT_CREATED = 101
ROCRAND_STATUS_ALLOCATION_FAILED = 102
ROCRAND_STATUS_TYPE_ERROR = 103
ROCRAND_STATUS_OUT_OF_RANGE = 104
ROCRAND_STATUS_LENGTH_NOT_MULTIPLE = 105
ROCRAND_STATUS_DOUBLE_PRECISION_REQUIRED = 106
ROCRAND_STATUS_LAUNCH_FAILURE = 107
ROCRAND_STATUS_INTERNAL_ERROR = 108
end

struct ROCRANDError <: Exception
code::rocrand_status
Expand Down
Loading