Skip to content

Commit 3e3171d

Browse files
smartalecHAlec Hammondstevengj
authored
Modify findfirst to return index rather than shape (#39)
* modify findfirst to return index rather than shape * use eachindex instead of 1:size * Use `Integer` not `Int` Co-authored-by: Steven G. Johnson <[email protected]> --------- Co-authored-by: Alec Hammond <[email protected]> Co-authored-by: Steven G. Johnson <[email protected]>
1 parent 3146db1 commit 3e3171d

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

src/util/kdtree.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@ not shapes of nonzero size.)
1616
"""
1717
mutable struct KDTree{K,S<:Shape{K}}
1818
s::Vector{S}
19+
s_index::Vector{Int}
1920
ix::Int
2021
x::Float64
2122
left::KDTree{K,S} # shapes ≤ x in coordinate ix
2223
right::KDTree{K,S} # shapes > x in coordinate ix
23-
KDTree{K,S}(s::AbstractVector{S}) where {K,S<:Shape{K}} = new(s, 0)
24+
KDTree{K,S}(s::AbstractVector{S}) where {K,S<:Shape{K}} = new(s, collect(eachindex(s)), 0)
25+
KDTree{K,S}(s::AbstractVector{S},s_index::Vector{<:Int}) where {K,S<:Shape{K}} = new(s, s_index, 0)
2426
function KDTree{K,S}(ix::Integer, x::Real, left::KDTree{K,S}, right::KDTree{K,S}) where {K,S<:Shape{K}}
2527
1 ix K || throw(BoundsError())
26-
new(S[], ix, x, left, right)
28+
new(S[], Int[], ix, x, left, right)
2729
end
2830
end
2931

@@ -38,8 +40,15 @@ Construct a K-D tree (`KDTree`) representation of a list of
3840
When searching the tree, shapes that appear earlier in `s`
3941
take precedence over shapes that appear later.
4042
"""
43+
4144
function KDTree(s::AbstractVector{S}) where {K,S<:Shape{K}}
42-
(length(s) 4 || K == 0) && return KDTree{K,S}(s)
45+
# If no list of indicies is provided, simply enumerate by the number of
46+
# shapes in `s`.
47+
return KDTree(s,collect(eachindex(s)))
48+
end
49+
50+
function KDTree(s::AbstractVector{S}, s_index::AbstractVector{<:Integer}) where {K,S<:Shape{K}}
51+
(length(s) 4 || K == 0) && return KDTree{K,S}(s, s_index)
4352

4453
# figure out the best dimension ix to divide over,
4554
# the dividing plane x, and the number (nl,nr) of
@@ -61,22 +70,26 @@ function KDTree(s::AbstractVector{S}) where {K,S<:Shape{K}}
6170
end
6271

6372
# don't bother subdividing if it doesn't reduce the # of shapes much
64-
4*max(nl,nr) > 3*length(s) && return KDTree{K,S}(s)
73+
4*max(nl,nr) > 3*length(s) && return KDTree{K,S}(s,s_index)
6574

6675
# create the arrays of shapes in each subtree
6776
sl = Vector{S}(undef, nl)
77+
sl_idx = Vector{Int}(undef, nl)
6878
sr = Vector{S}(undef, nr)
79+
sr_idx = Vector{Int}(undef, nr)
6980
il = ir = 0
7081
for k in eachindex(s)
7182
if b[k][1][ix] x
7283
sl[il += 1] = s[k]
84+
sl_idx[il] = s_index[k]
7385
end
7486
if b[k][2][ix] > x
7587
sr[ir += 1] = s[k]
88+
sr_idx[ir] = s_index[k]
7689
end
7790
end
7891

79-
return KDTree{K,S}(ix, x, KDTree(sl), KDTree(sr))
92+
return KDTree{K,S}(ix, x, KDTree(sl,sl_idx), KDTree(sr,sr_idx))
8093
end
8194

8295
depth(kd::KDTree) = kd.ix == 0 ? 0 : max(depth(kd.left), depth(kd.right)) + 1
@@ -104,7 +117,7 @@ function Base.findfirst(p::SVector{N}, s::Vector{S}) where {N,S<:Shape{N}}
104117
for i in eachindex(s)
105118
b = bounds(s[i])
106119
if all(b[1] .< p .< b[2]) && p s[i] # check if p is within bounding box is faster
107-
return s[i]
120+
return i
108121
end
109122
end
110123
return nothing
@@ -118,7 +131,12 @@ function Base.findfirst(p::SVector{N}, kd::KDTree{N}) where {N}
118131
return findfirst(p, kd.right)
119132
end
120133
else
121-
return findfirst(p, kd.s)
134+
idx = findfirst(p, kd.s)
135+
if isnothing(idx)
136+
return idx
137+
else
138+
return kd.s_index[idx]
139+
end
122140
end
123141
end
124142

test/kdtree.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
s = Shape2[Ball([i,0], 1) for i in 0:20]
88
kd = KDTree(s)
99
@test GeometryPrimitives.depth(kd) == 3
10-
@test findfirst([10.1,0], kd).c[1] == 10
10+
@test s[findfirst([10.1,0], kd)].c[1] == 10
1111
@test findfirst([10.1,1], kd) == nothing
1212
@test checktree(kd, s)
1313

0 commit comments

Comments
 (0)