diff --git a/src/solver/highlevel.jl b/src/solver/highlevel.jl index 9aef7ef5d..88c3541fc 100644 --- a/src/solver/highlevel.jl +++ b/src/solver/highlevel.jl @@ -380,27 +380,42 @@ for (fname, matrix_elty, vector_elty) in ( (:rocsolver_sgesvdj, :Float32, :Float32), ) @eval begin - function gesvdj!(A::ROCMatrix{$matrix_elty}, abstol::$vector_elty, max_sweeps::Cint) + function gesvdj!(A::ROCMatrix{$matrix_elty}, abstol::$vector_elty, max_sweeps::Cint; jobu::Char='A', + jobvt::Char='A') m, n = size(A) lda = max(1, stride(A, 2)) dev_residual = ROCVector{$vector_elty}(undef, 1) - dev_n_sweeps = ROCVector{Cint}(undef, 1) - S = ROCArray{$vector_elty}(undef, min(m, n)) - U = ROCMatrix{$matrix_elty}(undef, (m, min(m, n))) - ldu = m - @assert stride(U, 2) == ldu - V = ROCMatrix{$matrix_elty}(undef, (min(m, n), n)) - ldv = min(m, n) - @assert stride(V, 2) == ldv - dev_info = ROCVector{Cint}(undef, 1) + ldv = min(m, n) + k=ldv + ldu = m + + V = if jobvt === 'A' + ROCMatrix{$matrix_elty}(undef, (n, n)) + elseif jobvt === 'S' + ROCMatrix{$matrix_elty}(undef, (n, k)) + elseif jobvt === 'N' || jobvt === 'O' + C_NULL + else + error("jobvt must be one of 'A', 'S', 'O', or 'N'") + end + U = if jobu === 'A' + ROCMatrix{$matrix_elty}(undef, (m, m)) + elseif jobu === 'S' + ROCMatrix{$matrix_elty}(undef, (m, k)) + elseif jobu === 'N' || jobu === 'O' + C_NULL + else + error("jobu must be one of 'A', 'S', 'O', or 'N'") + end + $fname( rocBLAS.handle(), - rocblas_svect_singular, - rocblas_svect_singular, + jobu, + jobvt, m, n, A, lda, abstol, dev_residual, @@ -420,7 +435,7 @@ for (fname, matrix_elty, vector_elty) in ( info = AMDGPU.@allowscalar dev_info[1] AMDGPU.unsafe_free!(dev_info) - U, S, V', residual, n_sweeps, info + U, S, (jobvt === 'N' || jobvt === 'O') ? V : V', residual, n_sweeps, info end end end