Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 9032718

Browse files
authored
Try #695:
2 parents 2258a24 + 5970046 commit 9032718

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

src/solver/wrappers.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,3 +823,61 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
823823
end
824824
end
825825
end
826+
827+
# Batched Jacobi eigensolvers for symmetric (syevjBatched!) and Hermitian
# (heevjBatched!) matrices. One method is generated per element type; `relty`
# is the real type used for the eigenvalues and for the tolerance.
for (jname, bname, fname, elty, relty) in
        ((:syevjBatched!, :cusolverDnSsyevjBatched_bufferSize, :cusolverDnSsyevjBatched, :Float32, :Float32),
         (:syevjBatched!, :cusolverDnDsyevjBatched_bufferSize, :cusolverDnDsyevjBatched, :Float64, :Float64),
         (:heevjBatched!, :cusolverDnCheevjBatched_bufferSize, :cusolverDnCheevjBatched, :ComplexF32, :Float32),
         (:heevjBatched!, :cusolverDnZheevjBatched_bufferSize, :cusolverDnZheevjBatched, :ComplexF64, :Float64)
        )
    @eval begin
        # Arguments:
        #   jobz       - job selector forwarded through `cusolverjob`; 'N' returns
        #                only `W`, 'V' returns `(W, A)`
        #   uplo       - triangle selector forwarded through `cublasfill`
        #   A          - batch of square matrices; the third dimension is the batch
        #   tol        - tolerance handed to `cusolverDnXsyevjSetTolerance`
        #   max_sweeps - sweep limit handed to `cusolverDnXsyevjSetMaxSweeps`
        #
        # Returns `W` (size n x batchSize) when `jobz == 'N'`, or `(W, A)` when
        # `jobz == 'V'`. Throws `ArgumentError` if the solver reports an invalid
        # argument (negative info) for any batch entry.
        function $jname(jobz::Char,
                        uplo::Char,
                        A::CuArray{$elty};
                        tol::$relty=eps($relty),
                        max_sweeps::Int=100)

            # Set up information for the solver arguments
            cuuplo    = cublasfill(uplo)
            cujobz    = cusolverjob(jobz)
            n         = checksquare(A)
            lda       = max(1, stride(A, 2))
            batchSize = size(A, 3)
            W         = CuArray{$relty}(undef, n, batchSize)
            devinfo   = CuArray{Cint}(undef, batchSize)

            # Create the solver parameters; everything that uses them runs inside
            # `try` so the handle and the info buffer are released even on error.
            params = Ref{syevjInfo_t}(C_NULL)
            cusolverDnCreateSyevjInfo(params)
            info = try
                cusolverDnXsyevjSetTolerance(params[], tol)
                cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps)

                # Calculate the workspace size. The parameter handle is passed by
                # value (`params[]`), exactly as in the solver call below.
                lwork = @argout(CUSOLVER.$bname(dense_handle(), cujobz, cuuplo, n,
                                A, lda, W, out(Ref{Cint}(0)), params[], batchSize))[]

                # Run the solver
                @workspace eltyp=$elty size=lwork work->begin
                    $fname(dense_handle(), cujobz, cuuplo, n, A, lda, W, work,
                           lwork, devinfo, params[], batchSize)
                end

                # Copy the per-matrix solver status back to the host
                @allowscalar collect(devinfo)
            finally
                cusolverDnDestroySyevjInfo(params[])
                unsafe_free!(devinfo)
            end

            # Double check the solver's exit status: a negative value -k means
            # that the k-th argument was invalid.
            for (i, status) in enumerate(info)
                if status < 0
                    throw(ArgumentError("The $(-status)th parameter of the $(i)th solver is wrong"))
                end
            end

            # Return eigenvalues (in W) and possibly eigenvectors (in A)
            if jobz == 'N'
                return W
            elseif jobz == 'V'
                return W, A
            end
        end
    end
end

test/solver.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,60 @@ k = 1
279279
@test Eig.values h_W
280280
end
281281

282+
# Batched Jacobi eigensolver tests: compare the GPU results for every matrix in
# the batch against a CPU reference from `eigen(Hermitian(...))`.
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
    @testset "syevjBatched!" begin
        # Generate a batch of random symmetric (real) / Hermitian (complex)
        # matrices. The conjugate is required for complex element types:
        # A + transpose(A) alone is complex symmetric, not Hermitian.
        A = rand(elty, m, m, n)
        A += conj(permutedims(A, (2, 1, 3)))
        d_A = CuArray(A)

        # Pick the solver that matches the element type
        solver! = elty <: Complex ? CUSOLVER.heevjBatched! : CUSOLVER.syevjBatched!

        # Run the solver, requesting eigenvalues and eigenvectors
        local d_W, d_V
        d_W, d_V = solver!('V', 'U', d_A)

        # Pull the results back to the host
        h_W = collect(d_W)
        h_V = collect(d_V)

        # Use the CPU eigensolver as the reference
        for i = 1:n
            Eig = eigen(LinearAlgebra.Hermitian(A[:, :, i]))

            # Eigenvalues must agree; eigenvectors are only defined up to a
            # phase factor, hence the comparison of absolute overlaps with I.
            @test Eig.values ≈ h_W[:, i]
            @test abs.(Eig.vectors' * h_V[:, :, i]) ≈ I
        end

        # Do it all again, but with the option to not compute eigenvectors.
        # The solver works in place, so upload a fresh copy of the input first.
        d_A = CuArray(A)
        d_W = solver!('N', 'U', d_A)

        # Pull the results back to the host
        h_W = collect(d_W)

        # Compare against the CPU reference eigenvalues
        for i = 1:n
            Eig = eigen(LinearAlgebra.Hermitian(A[:, :, i]))
            @test Eig.values ≈ h_W[:, i]
        end
    end
end
335+
282336
@testset "svd with $method method" for
283337
method in (CUSOLVER.QRAlgorithm, CUSOLVER.JacobiAlgorithm),
284338
(_m, _n) in ((m, n), (n, m))

0 commit comments

Comments
 (0)