Skip to content
41 changes: 28 additions & 13 deletions src/solver/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,27 +380,42 @@ for (fname, matrix_elty, vector_elty) in (
(:rocsolver_sgesvdj, :Float32, :Float32),
)
@eval begin
function gesvdj!(A::ROCMatrix{$matrix_elty}, abstol::$vector_elty, max_sweeps::Cint)
function gesvdj!(A::ROCMatrix{$matrix_elty}, abstol::$vector_elty, max_sweeps::Cint; jobu::Char='A',
jobvt::Char='A')
m, n = size(A)
lda = max(1, stride(A, 2))
dev_residual = ROCVector{$vector_elty}(undef, 1)

dev_n_sweeps = ROCVector{Cint}(undef, 1)

S = ROCArray{$vector_elty}(undef, min(m, n))
U = ROCMatrix{$matrix_elty}(undef, (m, min(m, n)))
ldu = m
@assert stride(U, 2) == ldu
V = ROCMatrix{$matrix_elty}(undef, (min(m, n), n))
ldv = min(m, n)
@assert stride(V, 2) == ldv

dev_info = ROCVector{Cint}(undef, 1)
ldv = min(m, n)
k=ldv
ldu = m

V = if jobvt === 'A'
ROCMatrix{$matrix_elty}(undef, (n, n))
elseif jobvt === 'S'
ROCMatrix{$matrix_elty}(undef, (n, k))
elseif jobvt === 'N' || jobvt === 'O'
C_NULL
else
error("jobvt must be one of 'A', 'S', 'O', or 'N'")
end
U = if jobu === 'A'
ROCMatrix{$matrix_elty}(undef, (m, m))
elseif jobu === 'S'
ROCMatrix{$matrix_elty}(undef, (m, k))
elseif jobu === 'N' || jobu === 'O'
C_NULL
else
error("jobu must be one of 'A', 'S', 'O', or 'N'")
end


$fname(
rocBLAS.handle(),
rocblas_svect_singular,
rocblas_svect_singular,
jobu,
jobvt,
m, n, A, lda,
abstol,
dev_residual,
Expand All @@ -420,7 +435,7 @@ for (fname, matrix_elty, vector_elty) in (
info = AMDGPU.@allowscalar dev_info[1]
AMDGPU.unsafe_free!(dev_info)

U, S, V', residual, n_sweeps, info
U, S, (jobvt === 'N' || jobvt === 'O') ? V : V', residual, n_sweeps, info
end
end
end
Expand Down