diff --git a/Project.toml b/Project.toml index 4a4b9322..8e5a9c38 100644 --- a/Project.toml +++ b/Project.toml @@ -60,7 +60,7 @@ SafeTestsets = "0.1" SparseArrays = "1.10" StaticArrays = "1.6" StaticArraysCore = "1.4" -Statistics = "1.10" +Statistics = "1.10, 1.11" StructArrays = "0.6.11, 0.7" SymbolicIndexingInterface = "0.3.25" Tables = "1.11" diff --git a/src/array_partition.jl b/src/array_partition.jl index 4ee1a9ea..b0325fe0 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -209,13 +209,19 @@ function Base.copyto!(A::ArrayPartition, src::ArrayPartition) A end +function Base.fill!(A::ArrayPartition, x) + unrolled_foreach!(A.x) do x_ + fill!(x_, x) + end + A +end + function recursivefill!(b::ArrayPartition, a::T2) where {T2 <: Union{Number, Bool}} unrolled_foreach!(b.x) do x fill!(x, a) end end - ## indexing # Interface for the linear indexing. This is just a view of the underlying nested structure diff --git a/test/gpu/arraypartition_gpu.jl b/test/gpu/arraypartition_gpu.jl index c9a87dc8..3b335855 100644 --- a/test/gpu/arraypartition_gpu.jl +++ b/test/gpu/arraypartition_gpu.jl @@ -14,3 +14,7 @@ mask = pA .> 0 # Test recursive filling is done using GPU kernels and not scalar indexing RecursiveArrayTools.recursivefill!(pA, true) @test all(pA .== true) + +# Test that regular filling is done using GPU kernels and not scalar indexing +fill!(pA, false) +@test all(pA .== false)