Skip to content

Commit 8b27dfa

Browse files
authored
Merge pull request #4 from FelixBenning/1-lazy_product
Lazy Product
2 parents 9ffd33a + 76f4301 commit 8b27dfa

File tree

5 files changed

+8
-48
lines changed

5 files changed

+8
-48
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
99
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
12+
ProductArrays = "39a08dd0-b4b7-4086-afbb-eb1796240846"
1213
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1314

1415
[compat]

src/multiOutput.jl

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,7 @@
11

22
using MappedArrays: mappedarray, ReadonlyMappedArray
33

4-
function ensure_all_linear_indexed(vecs::T) where {T<:Tuple}
5-
linear_indexed = ntuple(
6-
n -> Base.IndexStyle(fieldtype(T,n)) === IndexLinear(),
7-
Base._counttuple(T)
8-
)
9-
all(linear_indexed) || throw(ArgumentError(
10-
"$(vecs[findfirst(x->!x, linear_indexed)]) cannot be linearly accessed. All inputs need to implement `Base.getindex(::T, ::Int)`"
11-
))
12-
end
13-
14-
struct ProductArray
15-
arrays
16-
end
17-
18-
"""
19-
lazy_product(vectors...)
20-
21-
The output is a lazy form of
22-
```julia
23-
collect(Iterators.product(vectors...))
24-
```
25-
i.e. it is an AbstractArray in contrast to `Iterators.product(vectors...)`. So
26-
is accessible with `getindex` and gets default Array implementations for free.
27-
In particular it can be passed to `Base.PermutedDimsArray`` for lazy permutation
28-
and `vec()` to obtain a lazy `Base.ReshapedArray`.
29-
"""
30-
function lazy_product(vectors...)
31-
ensure_all_linear_indexed(vectors)
32-
indices = CartesianIndices(ntuple(n -> axes(vec(vectors[n]), 1), length(vectors)))
33-
return mappedarray(indices) do idx
34-
return ntuple(n -> vec(vectors[n])[idx[n]], length(vectors))
35-
end
36-
end
4+
using ProductArrays: productArray
375

386
"""
397
lazy_flatten(vectors...)
@@ -49,19 +17,19 @@ and `vec()` to obtain a lazy `Base.ReshapedArray`.
4917
"""
5018
function lazy_flatten(vectors...)
5119
ensure_all_linear_indexed(vectors)
52-
lengths = cumsum(length.(vectors))
20+
lengths = cumsum(length.(vectors))
5321
return mappedarray(1:lengths[end]) do idx
5422
# this is not efficient for iteration - maybe go with LazyArrays.jl -> Vcat instead.
5523
v_idx = searchsortedfirst(lengths, idx)
56-
return vectors[v_idx][idx - get(lengths, v_idx-1, 0)]
24+
return vectors[v_idx][idx-get(lengths, v_idx - 1, 0)]
5725
end
5826
end
5927

6028

6129
function multi_out_byFeatures(positions, out_dims)
62-
return vec(PermutedDimsArray(lazy_product(positions, 1:out_dims), (2, 1)))
30+
return vec(PermutedDimsArray(productArray(positions, 1:out_dims), (2, 1)))
6331
end
6432
function multi_out_byOutput(positions, out_dims)
65-
return vec(lazy_product(positions, 1:out_dims))
33+
return vec(productArray(positions, 1:out_dims))
6634
end
6735

src/partial.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ end
5555
const DiffPt{T} = Tuple{T,Partial}
5656

5757
gradient(dim::Integer) = mappedarray(partial, Base.OneTo(dim))
58-
hessian(dim::Integer) = mappedarray(partial, lazy_product(Base.OneTo(dim), Base.OneTo(dim)))
59-
fullderivative(order::Integer,dim::Integer) = mappedarray(partial, lazy_product(ntuple(_->Base.OneTo(dim), order)...))
58+
hessian(dim::Integer) = mappedarray(partial, productArray(Base.OneTo(dim), Base.OneTo(dim)))
59+
fullderivative(order::Integer,dim::Integer) = mappedarray(partial, productArray(ntuple(_->Base.OneTo(dim), order)...))
6060

6161
# idea: lazy mappings can be undone (extract original range -> towards a specialization speedup of broadcasting over multiple derivatives using backwardsdiff)
6262
const MappedPartialVec{T} = ReadonlyMappedArray{Partial{1,Int},1,T,typeof(partial)}

test/multiOutput.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ using Test
66
List of Testfiles without extension. `\$(test).jl"` should be a file for every test in AVAILABLE_TESTS
77
"""
88
const AVAILABLE_TESTS = [
9-
"multiOutput",
109
"diffKernel",
1110
]
1211

0 commit comments

Comments
 (0)