Skip to content

Commit f8483b5

Browse files
committed
support map with GPU arrays
1 parent 0b99716 commit f8483b5

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/structarray.jl

+2
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,8 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
495495
toplevel && print(io, " with eltype ", T)
496496
end
497497

498+
Base.map(f, s::StructArray) = f.(s)
499+
498500
# broadcast
499501
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
500502
using Base.Broadcast: combine_styles

test/runtests.jl

+9
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,15 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
13771377
@test @inferred(bcmul2(sa)) isa StructArray
13781378
@test backend(bcmul2(sa)) === backend(sa)
13791379
@test (sa .+= 1) === sa
1380+
1381+
@test_broken collect(sa)
1382+
1383+
a2 = map(x -> real(x) + 1, sa)
1384+
@test a2::JLArray == sa.re .+ 1
1385+
sa2 = map(x -> x + 1, sa)
1386+
@test sa2.re::JLArray == sa.re .+ 1
1387+
sa3 = map(x -> (a=x + 1, b=x.re + x.im), sa)
1388+
@test sa3.b::JLArray == sa.re .+ sa.im
13801389
end
13811390

13821391
@testset "StructSparseArray" begin

0 commit comments

Comments
 (0)