Skip to content

Commit

Permalink
fix: ArrayPartition arithmetic type-stability
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 8, 2024
1 parent ef55922 commit efac89a
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ for op in (:*, :/)
end

function Base.:*(A::Number, B::ArrayPartition)
ArrayPartition(map(y -> Base.broadcast(*, A, y), B.x))
ArrayPartition(map(y -> A .* y, B.x))

Check warning on line 156 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L156

Added line #L156 was not covered by tests
end

function Base.:\(A::Number, B::ArrayPartition)
ArrayPartition(map(y -> Base.broadcast(/, y, A), B.x))
B / A

Check warning on line 160 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L160

Added line #L160 was not covered by tests
end

Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x
Expand Down Expand Up @@ -284,7 +284,7 @@ recursive_eltype(A::ArrayPartition) = recursive_eltype(first(A.x))
Base.iterate(A::ArrayPartition) = iterate(Chain(A.x))
Base.iterate(A::ArrayPartition, state) = iterate(Chain(A.x), state)

Base.length(A::ArrayPartition) = sum((length(x) for x in A.x))
Base.length(A::ArrayPartition) = sum(broadcast(length, A.x))

Check warning on line 287 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L287

Added line #L287 was not covered by tests
Base.size(A::ArrayPartition) = (length(A),)

# redefine first and last to avoid slow and not type-stable indexing
Expand Down Expand Up @@ -323,21 +323,13 @@ function Broadcast.BroadcastStyle(::ArrayPartitionStyle,
Broadcast.DefaultArrayStyle{N}()
end

combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
@inline function combine_styles(args::Tuple{Any})
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
end
@inline function combine_styles(args::Tuple{Any, Any})
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]),
Broadcast.BroadcastStyle(args[2]))
end
@inline function combine_styles(args::Tuple)
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]),
combine_styles(Base.tail(args)))
@generated function combine_styles(t)
@show t
return :($(reduce(Broadcast.result_style, Broadcast.BroadcastStyle.(t.parameters[1].parameters))))

Check warning on line 328 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L326-L328

Added lines #L326 - L328 were not covered by tests
end

function Broadcast.BroadcastStyle(::Type{ArrayPartition{T, S}}) where {T, S}
Style = combine_styles((S.parameters...,))
Style = combine_styles(S)

Check warning on line 332 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L332

Added line #L332 was not covered by tests
ArrayPartitionStyle(Style)
end

Expand Down

0 comments on commit efac89a

Please sign in to comment.