Skip to content

Commit

Permalink
fix: fix views with boolean mask
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 17, 2024
1 parent a64ff30 commit c44a776
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,13 @@ function Base.view(A::AbstractVectorOfArray{T,N,<:AbstractVector{T}}, I::Vararg{
J = map(i->Base.unalias(A,i), to_indices(A, Base.tail(I)))
end
@boundscheck checkbounds(A, J...)
SubArray(IndexStyle(A), A, J, Base.index_dimsum(J...))
SubArray(A, J)

Check warning on line 531 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L531

Added line #L531 was not covered by tests
end
function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M}
@inline
J = map(i->Base.unalias(A,i), to_indices(A, I))
@boundscheck checkbounds(A, J...)
SubArray(IndexStyle(A), A, J, Base.index_dimsum(J...))
SubArray(A, J)

Check warning on line 537 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L537

Added line #L537 was not covered by tests
end
function Base.SubArray(parent::AbstractVectorOfArray, indices::Tuple)
@inline
Expand Down
4 changes: 3 additions & 1 deletion test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ arrvb = Array(testvb)
# view
testvc = VectorOfArray([rand(1:10, 3, 3) for _ in 1:3])
arrvc = Array(testvc)
for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, : ,:), (:, :, :)]
for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, : ,:), (:, :, :), (1:2, 1:2, Bool[1, 0, 1]), (1:2, Bool[1, 0, 1], 1:2), (Bool[1, 0, 1], 1:2, 1:2)]
arr_view = view(arrvc, idxs...)
voa_view = view(testvc, idxs...)
@test size(arr_view) == size(voa_view)
Expand All @@ -93,11 +93,13 @@ end

testvc = VectorOfArray(collect(1:10))
arrvc = Array(testvc)
bool_idx = rand(Bool, 10)
for (voaidx, arridx) in [
((:,), (:,)),
((3:5,), (3:5,)),
((:, 3:5), (3:5,)),
((1, 3:5), (3:5,)),
((:, bool_idx), (bool_idx,))
]
arr_view = view(arrvc, arridx...)
voa_view = view(testvc, voaidx...)
Expand Down

0 comments on commit c44a776

Please sign in to comment.